From 904374d329a648a9d4e587fdb7fc2c94ebcbb816 Mon Sep 17 00:00:00 2001 From: "suluyan.sly" Date: Mon, 5 Sep 2022 20:58:08 +0800 Subject: [PATCH] [to #42322933] feat: plug inference Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9931748 --- modelscope/metainfo.py | 2 + modelscope/models/nlp/__init__.py | 2 + modelscope/models/nlp/plug/__init__.py | 27 + .../models/nlp/plug/configuration_plug.py | 232 ++++ .../models/nlp/plug/distributed_plug.py | 191 +++ modelscope/models/nlp/plug/modeling_plug.py | 1054 +++++++++++++++++ modelscope/pipelines/base.py | 108 ++ .../nlp/distributed_plug_pipeline.py | 107 ++ modelscope/preprocessors/nlp.py | 3 +- modelscope/trainers/trainer.py | 7 +- modelscope/utils/nlp/distributed.py | 130 ++ modelscope/utils/nlp/load_checkpoint.py | 117 ++ modelscope/utils/torch_utils.py | 21 +- requirements/nlp.txt | 2 + tests/pipelines/test_plug_text_generation.py | 49 + 15 files changed, 2044 insertions(+), 8 deletions(-) create mode 100644 modelscope/models/nlp/plug/__init__.py create mode 100644 modelscope/models/nlp/plug/configuration_plug.py create mode 100644 modelscope/models/nlp/plug/distributed_plug.py create mode 100644 modelscope/models/nlp/plug/modeling_plug.py create mode 100644 modelscope/pipelines/nlp/distributed_plug_pipeline.py create mode 100755 modelscope/utils/nlp/distributed.py create mode 100755 modelscope/utils/nlp/load_checkpoint.py create mode 100644 tests/pipelines/test_plug_text_generation.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 3ac2f2df..792bd708 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -55,6 +55,7 @@ class Models(object): lcrf = 'lstm-crf' bart = 'bart' gpt3 = 'gpt3' + plug = 'plug' bert_for_ds = 'bert-for-document-segmentation' # audio models @@ -172,6 +173,7 @@ class Pipelines(object): dialog_state_tracking = 'dialog-state-tracking' zero_shot_classification = 'zero-shot-classification' text_error_correction = 'text-error-correction' + plug_generation = 'plug-generation' faq_question_answering = 'faq-question-answering' conversational_text_to_sql = 'conversational-text-to-sql' relation_extraction = 'relation-extraction' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index fd61e40b..9d54834c 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: SingleBackboneTaskModelBase) from .bart_for_text_error_correction import BartForTextErrorCorrection from .gpt3 import GPT3ForTextGeneration + from .plug import PlugForTextGeneration from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering else: @@ -60,6 +61,7 @@ else: ], 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], 'gpt3': ['GPT3ForTextGeneration'], + 'plug': ['PlugForTextGeneration'], 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], } diff --git a/modelscope/models/nlp/plug/__init__.py b/modelscope/models/nlp/plug/__init__.py new file mode 100644 index 00000000..b74258a4 --- /dev/null +++ b/modelscope/models/nlp/plug/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
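+# The classes below are exposed through LazyImportModule so that the heavy
+# dependencies of the PLUG implementation (e.g. megatron and deepspeed) are
+# only imported when one of these symbols is actually accessed.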
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration_plug import PlugNLGConfig + from .modeling_plug import PlugModel + from .distributed_plug import DistributedPlug + from .plug_for_text_generation import PlugForTextGeneration +else: + _import_structure = { + 'configuration_plug': ['PlugNLGConfig'], + 'modeling_plug': ['PlugModel'], + 'distributed_plug': ['DistributedPlug'], + 'plug_for_text_generation': ['PlugForTextGeneration'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/plug/configuration_plug.py b/modelscope/models/nlp/plug/configuration_plug.py new file mode 100644 index 00000000..64807392 --- /dev/null +++ b/modelscope/models/nlp/plug/configuration_plug.py @@ -0,0 +1,232 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import json +from transformers import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class PlugNLUConfig(PretrainedConfig): + model_type = 'plugNLU' + + def __init__(self, + vocab_size=21504, + original_vocab_size=21128, + hidden_size=8192, + num_hidden_layers=24, + num_attention_heads=128, + intermediate_size=32768, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=3, + initializer_range=0.00707, + deep_init=False, + deepspeed=False, + lr_decay_style='linear', + weight_decay=1e-2, + clip_grad=1.0, + warmup=0.0333, + pre_ln=True, + fp16=True, + fp32_layernorm=True, + fp32_embedding=False, + fp32_tokentypes=False, + layernorm_epsilon=1e-5, + dec_hidden_layers=6, + pruning_method=None, + pruning_mask_init='constant', + pruning_mask_scale=0.0, + pruning_initial_threshold=1.0, + pruning_final_threshold=0.01, + pruning_initial_warmup=1, + pruning_final_warmup=20, + pruning_module='decoder', + pruning_decay_step=50, + pruning_decay_type='exp', + ft_module=None, + attn_separate=False, + LR_weight_rank=8, + LR_mask_rank=8, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.original_vocab_size = original_vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.deep_init = deep_init + 
self.deepspeed = deepspeed + self.lr_decay_style = lr_decay_style + self.weight_decay = weight_decay + self.clip_grad = clip_grad + self.warmup = warmup + self.pre_ln = pre_ln + self.fp16 = fp16 + self.fp32_layernorm = fp32_layernorm + self.fp32_embedding = fp32_embedding + self.layernorm_epsilon = layernorm_epsilon + self.fp32_tokentypes = fp32_tokentypes + self.dec_hidden_layers = dec_hidden_layers + self.pruning_method = pruning_method + self.pruning_mask_init = pruning_mask_init + self.pruning_mask_scale = pruning_mask_scale + self.pruning_module = pruning_module + self.pruning_initial_threshold = pruning_initial_threshold + self.pruning_final_threshold = pruning_final_threshold + self.pruning_initial_warmup = pruning_initial_warmup + self.pruning_final_warmup = pruning_final_warmup + self.pruning_decay_step = pruning_decay_step + self.pruning_decay_type = pruning_decay_type + self.ft_module = ft_module + self.attn_separate = attn_separate + self.LR_weight_rank = LR_weight_rank + self.LR_mask_rank = LR_mask_rank + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = PlugNLUConfig() + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, 'r', encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def merge_args(self, args): + """merge values a `BertConfig` from a json file of parameters.""" + local_keys = self.__dict__.keys() + for key, value in args.__dict__.items(): + if key in local_keys: + continue + self.__dict__[key] = value + return self + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + +class PlugNLGConfig(PlugNLUConfig): + model_type = 'plugNLG' + + def __init__(self, + vocab_size=21504, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.00707, + deep_init=False, + deepspeed=False, + lr_decay_style='linear', + weight_decay=1e-2, + clip_grad=1.0, + warmup=0.01, + pre_ln=False, + fp16=False, + fp32_layernorm=False, + fp32_embedding=False, + fp32_tokentypes=False, + layernorm_epsilon=1e-12, + dec_hidden_layers=6, + pruning_method=None, + pruning_mask_init='constant', + pruning_mask_scale=0.0, + pruning_initial_threshold=1.0, + pruning_final_threshold=0.01, + pruning_initial_warmup=1, + pruning_final_warmup=20, + pruning_module='decoder', + pruning_decay_step=50, + pruning_decay_type='exp', + ft_module=None, + attn_separate=False, + LR_weight_rank=8, + LR_mask_rank=8, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + 
self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.deep_init = deep_init + self.deepspeed = deepspeed + self.lr_decay_style = lr_decay_style + self.weight_decay = weight_decay + self.clip_grad = clip_grad + self.warmup = warmup + self.pre_ln = pre_ln + self.fp16 = fp16 + self.fp32_layernorm = fp32_layernorm + self.fp32_embedding = fp32_embedding + self.layernorm_epsilon = layernorm_epsilon + self.fp32_tokentypes = fp32_tokentypes + self.dec_hidden_layers = dec_hidden_layers + self.pruning_method = pruning_method + self.pruning_mask_init = pruning_mask_init + self.pruning_mask_scale = pruning_mask_scale + self.pruning_module = pruning_module + self.pruning_initial_threshold = pruning_initial_threshold + self.pruning_final_threshold = pruning_final_threshold + self.pruning_initial_warmup = pruning_initial_warmup + self.pruning_final_warmup = pruning_final_warmup + self.pruning_decay_step = pruning_decay_step + self.pruning_decay_type = pruning_decay_type + self.ft_module = ft_module + self.attn_separate = attn_separate + self.LR_weight_rank = LR_weight_rank + self.LR_mask_rank = LR_mask_rank diff --git a/modelscope/models/nlp/plug/distributed_plug.py b/modelscope/models/nlp/plug/distributed_plug.py new file mode 100644 index 00000000..2992f595 --- /dev/null +++ b/modelscope/models/nlp/plug/distributed_plug.py @@ -0,0 +1,191 @@ +import os +from typing import Dict + +import torch +import torch.nn.functional as F +from megatron import mpu +from megatron.fp16 import FP16_Module +from megatron.utils import print_rank_0 + +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.utils.logger import get_logger +from modelscope.utils.nlp.distributed import initialize_distributed +from modelscope.utils.nlp.load_checkpoint import pre_load +from modelscope.utils.torch_utils import set_random_seed_mpu +from . import PlugModel +from .configuration_plug import PlugNLGConfig + +logger = get_logger(__name__) + + +class DistributedPlug(TorchModel): + + def __init__(self, model_dir, rank, **kwargs): + super().__init__(model_dir, **kwargs) + self.rank = rank + self.model_cfg = kwargs + self.config = PlugNLGConfig.from_pretrained(model_dir) + initialize_distributed(rank, mpu, kwargs['world_size'], + kwargs['model_parallel_size'], + kwargs['master_ip'], kwargs['master_port']) + seed = 0 if 'seed' not in kwargs else kwargs['seed'] + set_random_seed_mpu(seed) + self.iteration = 0 + self.dist_model = self.initialize_model(path_load_tag='model') + + def initialize_model(self, path_load_tag='model'): + """Build the model.""" + print_rank_0('Building Plug model. It will take a few minutes ...') + model = PlugModel(self.config) + + if mpu.get_data_parallel_rank() == 0: + logger.info( + ' > number of parameters on model parallel rank {}: {}'.format( + mpu.get_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()]))) + + if self.config.deepspeed and self.config.fp16: + model.half() + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. 
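+        # Wrap the model with megatron's FP16_Module, then selectively keep the
+        # embeddings and LayerNorm modules in fp32 according to the fp32_* flags
+        # in the config.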
+ if self.config.fp16: + model = FP16_Module(model) + if self.config.fp32_embedding: + model.module.model.bert.embeddings.word_embeddings.float() + model.module.model.bert.embeddings.position_embeddings.float() + model.module.model.bert.embeddings.token_type_embeddings.float( + ) + if self.config.fp32_tokentypes: + model.module.model.bert.embeddings.token_type_embeddings.float( + ) + if self.config.fp32_layernorm: + for name, _module in model.named_modules(): + if 'LayerNorm' in name: + _module.float() + + load_model = pre_load(mpu, self.model_dir, tag=path_load_tag) + model_dict = model.module.model.state_dict() + for key in load_model: + if key not in model_dict.keys(): + print_rank_0('Skip key: ' + key) + else: + print_rank_0('Loading key: ' + key) + model.module.model.load_state_dict(load_model, strict=False) + return model + + @staticmethod + def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art- + # conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, + None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # convert to 1D + logits = logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + # going back to 2D + logits = logits.view(1, -1).contiguous() + return logits + + def generate(self, input: Dict[str, Tensor], out_length=128, *kwargs): + device = torch.cuda.current_device() + batch_size = input['input_ids'].shape[0] + tokens = input['input_ids'].view(1, -1).contiguous().to(device) + dec_input_ids = input['dec_input_ids'].to(device) + attention_mask = input['attention_mask'].to(device) + self.dist_model.eval() + with torch.no_grad(): + # Only supports batch_size=1 + all_generate_tokens = [] + generate_tokens = [] + counter = 0 + sequence_output = None + vocab_size = self.config.original_vocab_size + sep_token_idx = 102 # index of [SEP] token in BertTokenizer + while counter < out_length: + if counter % 128 == 0 and counter != 0: + # Sliding window + generate_tokens.append(sep_token_idx) + start = (tokens == sep_token_idx).nonzero( + as_tuple=True)[-1] + if start + len(generate_tokens) >= 512: + tokens = torch.cat([ + tokens[:start], + torch.cuda.LongTensor(generate_tokens) + ], -1)[-512:] + else: + tokens[0][start:start + len(generate_tokens + )] = torch.cuda.LongTensor( + generate_tokens) + + attention_mask = (tokens != 0) + dec_input_ids = input['dec_input_ids'].to(device) + generate_tokens = [] + sequence_output = None + + position_ids = torch.full([batch_size, 1], + len(generate_tokens), + dtype=torch.long, + device=device) + _, logits, sequence_output = self.dist_model( + tokens, + None, + attention_mask, + dec_input_ids, + attention_mask, + position_ids, + 
is_infer=True, + sequence_output=sequence_output, + parallel_output=False) + logits = logits[:, -1, :] + logits = logits / self.model_cfg['temperature'] + logits = self.top_k_logits( + logits, + top_k=self.model_cfg['top_k'], + top_p=self.model_cfg['top_p']) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1) + prev_token = prev[0].item() + if prev_token >= vocab_size: + prev_token = 100 + prev[0] = 100 + if prev_token == 102 and len(all_generate_tokens) > int( + max(1, out_length) * 0.8): + break + if prev_token == 102: + counter += 1 + continue + dec_input_ids = torch.cat([dec_input_ids, prev], dim=1) + generate_tokens.append(prev_token) + all_generate_tokens.append(prev_token) + counter += 1 + + generate_context = [] + for token in all_generate_tokens: + if generate_context and generate_context[ + -1] == 100 and token == 100: + continue + else: + generate_context.append(token) + return {'generate_context': generate_context} diff --git a/modelscope/models/nlp/plug/modeling_plug.py b/modelscope/models/nlp/plug/modeling_plug.py new file mode 100644 index 00000000..9d2bb14f --- /dev/null +++ b/modelscope/models/nlp/plug/modeling_plug.py @@ -0,0 +1,1054 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import logging +import math +import os + +import torch +import torch.nn.functional as F +from deepspeed.utils.timer import SynchronizedWallClockTimer +from megatron import mpu +from torch import nn + +from modelscope.utils.nlp.distributed import (normal_init_method, + scaled_init_method) +from .configuration_plug import PlugNLGConfig, PlugNLUConfig + +logger = logging.getLogger(__name__) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class BertLayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). 
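+        The forward pass computes weight * (x - u) / sqrt(s + eps) + bias, where
+        u and s are the mean and variance taken over the last dimension.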
+ """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + init_method=normal_init_method( + mean=0.0, std=config.initializer_range)) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.fp32_layernorm = config.fp32_layernorm + self.fp32_embedding = config.fp32_embedding + self.fp32_tokentypes = config.fp32_tokentypes + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + if not self.fp32_tokentypes: + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + if self.fp32_embedding and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_embedding: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + else: + embeddings = words_embeddings.float() + position_embeddings.float( + ) + token_type_embeddings.float() + if self.fp32_tokentypes and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_tokentypes: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method( + mean=0.0, std=config.initializer_range) + self.dense = mpu.RowParallelLinear( + input_size=config.hidden_size, + output_size=config.hidden_size, + bias=True, + input_is_parallel=True, + stride=1, + init_method=init_method, + pruning_method=config.pruning_method if config.pruning_module in [ + 'all', 
'encoder', 'encoder_self', 'encoder_selfvo', + 'encoder_selfo' + ] else None, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank) + self.fp32_layernorm = config.fp32_layernorm + if not config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states, + input_tensor, + pruning_threshold=None, + ): + hidden_states = self.dense( + hidden_states, + pruning_threshold=pruning_threshold, + ) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + if self.LayerNorm is not None: + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + else: + hidden_states = ln_input + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config): + super(BertAttention, self).__init__() + self.fp32_layernorm = config.fp32_layernorm + if config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + self.self = mpu.BertParallelSelfAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + dropout_prob=config.attention_probs_dropout_prob, + output_parallel=True, + init_method=normal_init_method( + mean=0.0, std=config.initializer_range), + separate=config.attn_separate, + pruning_method=config.pruning_method, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + pruning_module=config.pruning_module, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank) + self.output = BertSelfOutput(config) + + def forward( + self, + input_tensor, + attention_mask, + pruning_threshold=None, + ): + if self.LayerNorm is not None: + ln_input = input_tensor + previous_type = input_tensor.type() + if self.fp32_layernorm: + ln_input = input_tensor.float() + ln_output = self.LayerNorm(ln_input) + if self.fp32_layernorm: + ln_output = ln_output.type(previous_type) + self_output = self.self( + ln_output, + attention_mask, + pruning_threshold=pruning_threshold, + ) + else: + self_output = self.self( + input_tensor, + attention_mask, + pruning_threshold=pruning_threshold, + ) + output_pruning_threshold = pruning_threshold + + attention_output = self.output( + self_output, + input_tensor, + pruning_threshold=output_pruning_threshold, + ) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = mpu.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + bias=True, + gather_output=False, + stride=1, + init_method=normal_init_method( + mean=0.0, std=config.initializer_range), + pruning_method=config.pruning_method if config.pruning_module + in ['all', 'encoder', 'encoder_ffn'] else None, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward( + self, + hidden_states, + pruning_threshold=None, + ): + 
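+        # Column-parallel projection from hidden_size to intermediate_size,
+        # followed by the configured activation function (gelu by default).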
hidden_states = self.dense( + hidden_states, + pruning_threshold=pruning_threshold, + ) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method( + mean=0.0, std=config.initializer_range) + self.dense = mpu.RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + input_is_parallel=True, + stride=1, + init_method=init_method, + pruning_method=config.pruning_method if config.pruning_module + in ['all', 'encoder', 'encoder_ffn'] else None, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank) + self.fp32_layernorm = config.fp32_layernorm + if not config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states, + input_tensor, + pruning_threshold=None, + ): + hidden_states = self.dense( + hidden_states, + pruning_threshold=pruning_threshold, + ) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + if self.LayerNorm is not None: + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + else: + hidden_states = ln_input + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + self.fp32_layernorm = config.fp32_layernorm + if config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + + def forward( + self, + hidden_states, + attention_mask, + pruning_threshold=None, + ): + attention_output = self.attention( + hidden_states, attention_mask, pruning_threshold=pruning_threshold) + if self.LayerNorm is not None: + ln_input = attention_output + previous_type = attention_output.type() + if self.fp32_layernorm: + ln_input = attention_output.float() + ln_output = self.LayerNorm(ln_input) + if self.fp32_layernorm: + ln_output = ln_output.type(previous_type) + intermediate_output = self.intermediate( + ln_output, pruning_threshold=pruning_threshold) + else: + intermediate_output = self.intermediate( + attention_output, pruning_threshold=pruning_threshold) + layer_output = self.output( + intermediate_output, + attention_output, + pruning_threshold=pruning_threshold) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super(BertEncoder, self).__init__() + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.fp32_layernorm = config.fp32_layernorm + if config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + + def forward( + self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + checkpoint_activations=False, + 
detach_index=-1, + pruning_threshold=None, + ): + all_encoder_layers = [] + + def custom(start, end): + + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer( + x_, inputs[1], pruning_threshold=pruning_threshold) + return x_ + + return custom_forward + + if checkpoint_activations: + layer_idx = 0 + num_layers = len(self.layer) + chunk_length = 1 + while layer_idx < num_layers: + hidden_states = mpu.checkpoint( + custom(layer_idx, layer_idx + chunk_length), hidden_states, + attention_mask * 1) + if detach_index == layer_idx: + hidden_states.detach_() + layer_idx += chunk_length + # decoder layers + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attention_mask) + if detach_index == i: + hidden_states.detach_() + if i == len(self.layer) - 1 and self.LayerNorm is not None: + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + + if not output_all_encoded_layers or checkpoint_activations: + if self.LayerNorm is not None: + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.fp32_layernorm = config.fp32_layernorm + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
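+        # Weight tying: the output projection reuses the (model-parallel) word
+        # embedding matrix passed in as bert_model_embedding_weights.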
+ self.decoder_weight = bert_model_embedding_weights + self.bias = nn.Parameter( + torch.zeros(bert_model_embedding_weights.size(0))) + self.bias.model_parallel = True + self.fp32_embedding = config.fp32_embedding + self.fp32_layernorm = config.fp32_layernorm + + def convert_to_type(tensor): + if self.fp32_embedding: + return tensor.half() + else: + return tensor + + self.type_converter = convert_to_type + self.converted = False + self.timers = SynchronizedWallClockTimer() + + def forward(self, hidden_states): + if not self.converted: + self.converted = True + if self.fp32_embedding: + self.transform.half() + if self.fp32_layernorm: + self.transform.LayerNorm.float() + hidden_states = self.transform(self.type_converter(hidden_states)) + self.timers('final linear gather').start() + hidden_states = mpu.copy_to_model_parallel_region(hidden_states) + self.timers('final linear gather').stop() + hidden_states = F.linear( + self.type_converter(hidden_states), + self.type_converter(self.decoder_weight), + self.type_converter(self.bias)) + return hidden_states + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 3) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + for p in self.seq_relationship.parameters(): + if p is None: + continue + pooled_output = pooled_output.type_as(p) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class PreTrainedBertModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedBertModel, self).__init__() + if not isinstance(config, PlugNLUConfig) and not isinstance( + config, PlugNLGConfig): + raise ValueError( + 'Parameter config in `{}(config)` should be an instance of class `BertConfig`. ' + 'To create a model from a Google pretrained model use ' + '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( + self.__class__.__name__, self.__class__.__name__)) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(PreTrainedBertModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. 
Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as + described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLF`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + output_all_encoded_layers=True, + checkpoint_activations=False, + detach_index=-1, + pruning_threshold=None, + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
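+        # For example, an attention_mask of [1, 1, 0] becomes an additive bias of
+        # [0.0, 0.0, -10000.0] on the pre-softmax attention scores.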
+ extended_attention_mask = extended_attention_mask.to( + dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + checkpoint_activations=checkpoint_activations, + detach_index=detach_index, + pruning_threshold=pruning_threshold) + sequence_output = encoded_layers[-1] + for p in self.pooler.parameters(): + if p is None: + continue + sequence_output = sequence_output.type_as(p) + break + + pooled_output = sequence_output[:, 0] + if not output_all_encoded_layers or checkpoint_activations: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class DecodeLayer(nn.Module): + + def __init__(self, config): + super(DecodeLayer, self).__init__() + init_method = normal_init_method( + mean=0.0, std=config.initializer_range) + output_layer_init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + + self_pruning_method = config.pruning_method + cross_pruning_method = config.pruning_method + ffn_pruning_method = config.pruning_method + + if config.ft_module is not None: + if 'decoder_self' in config.ft_module: + self_pruning_method = 'finetune' + if 'decoder_cross' in config.ft_module: + cross_pruning_method = 'finetune' + if 'decoder_ffn' in config.ft_module: + ffn_pruning_method = 'finetune' + + self.attention = mpu.GPT2ParallelSelfAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + attention_dropout_prob=config.attention_probs_dropout_prob, + output_dropout_prob=config.hidden_dropout_prob, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + pruning_method=self_pruning_method if config.pruning_module in [ + 'all', 'decoder', 'decoder_self', 'decoder_self+ffn' + ] else None, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank, + ) + + self.cross_attention = mpu.PalmParallelCrossAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + attention_dropout_prob=config.attention_probs_dropout_prob, + output_dropout_prob=config.hidden_dropout_prob, + init_method=init_method, + attn_separate=False, + output_layer_init_method=output_layer_init_method, + pruning_method=cross_pruning_method, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + pruning_module=config.pruning_module, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank, + ) + + self.input_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.post_attention_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.post_cross_attention_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + self.intermediate = mpu.ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + gather_output=False, + init_method=init_method, + pruning_method=ffn_pruning_method if config.pruning_module + in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + LR_weight_rank=config.LR_weight_rank, + 
LR_mask_rank=config.LR_mask_rank, + ) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.output = mpu.RowParallelLinear( + config.intermediate_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + pruning_method=ffn_pruning_method if config.pruning_module + in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None, + pruning_mask_init=config.pruning_mask_init, + pruning_mask_scale=config.pruning_mask_scale, + LR_weight_rank=config.LR_weight_rank, + LR_mask_rank=config.LR_mask_rank, + ) + + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.fp32_layernorm = config.fp32_layernorm + + def convert_to_type(tensor): + if self.fp32_layernorm: + return tensor.float() + else: + return tensor + + self.type_converter = convert_to_type + + # def forward(self, hidden_states, enc_attn_mask, dec_attn_mask): + def forward(self, + hidden_states, + enc_hidden_states, + enc_attn_mask, + dec_attn_mask, + is_infer=False, + pruning_threshold=None): + residual = hidden_states + previous_type = hidden_states.type() + hidden_states = self.input_layernorm( + self.type_converter(hidden_states)) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + hidden_states = self.attention( + hidden_states, + dec_attn_mask, + is_infer=is_infer, + pruning_threshold=pruning_threshold) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm( + self.type_converter(hidden_states)) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + hidden_states = self.cross_attention( + hidden_states, + enc_hidden_states, + enc_attn_mask, + pruning_threshold=pruning_threshold) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_cross_attention_layernorm( + self.type_converter(hidden_states)) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + hidden_states = self.intermediate( + hidden_states, pruning_threshold=pruning_threshold) + hidden_states = self.intermediate_act_fn(hidden_states) + + hidden_states = self.output( + hidden_states, pruning_threshold=pruning_threshold) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class BertDecoder(nn.Module): + + def __init__(self, config): + super(BertDecoder, self).__init__() + self.layer = nn.ModuleList( + [DecodeLayer(config) for _ in range(config.dec_hidden_layers)]) + + self.final_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.fp32_layernorm = config.fp32_layernorm + + def forward(self, + hidden_states, + enc_hidden_states, + enc_attn_mask, + dec_attn_mask, + checkpoint_activations=False, + output_all_encoded_layers=False, + is_infer=False, + pruning_threshold=None): + + def custom(start, end): + + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer( + x_, + inputs[1], + inputs[2], + dec_attn_mask * 1, + is_infer=is_infer, + pruning_threshold=pruning_threshold) + return x_ + + return custom_forward + + pre_enc_hidden = enc_hidden_states.data + if checkpoint_activations: + layer_idx = 0 + num_layers = len(self.layer) + chunk_length = 1 + while layer_idx < num_layers: + hidden_states = mpu.checkpoint( + custom(layer_idx, layer_idx + chunk_length), hidden_states, + enc_hidden_states, enc_attn_mask * 1) + 
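+                # Restore the encoder hidden states saved in pre_enc_hidden above;
+                # activation checkpointing may have replaced their underlying .data.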
enc_hidden_states.data = pre_enc_hidden + layer_idx += chunk_length + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states, + enc_hidden_states, + enc_attn_mask, + dec_attn_mask, + is_infer=is_infer, + pruning_threshold=pruning_threshold) + + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.final_layernorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + + return [hidden_states] + + +class DecodeModel(PreTrainedBertModel): + + def __init__(self, config): + super(DecodeModel, self).__init__(config) + self.decoder = BertDecoder(config) + self.apply(self.init_bert_weights) + + def forward(self, + embeddings, + sequence_output, + decode_input_ids, + position_ids=None, + enc_attn_mask=None, + dec_attn_mask=None, + checkpoint_activations=False, + is_infer=False, + pruning_threshold=None): + extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = embeddings(decode_input_ids) + sequence_output = self.decoder( + embedding_output, + sequence_output, + extended_attention_mask, + dec_attn_mask, + checkpoint_activations=False, + is_infer=is_infer, + pruning_threshold=pruning_threshold) + return sequence_output[-1] + + +class PalmForPreTraining(PreTrainedBertModel): + + def __init__(self, config): + super(PalmForPreTraining, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads( + config, self.bert.embeddings.word_embeddings.weight) + self.decoder = DecodeModel(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + decode_input_ids=None, + position_ids=None, + decode_attention_mask=None, + lm_labels=None, + checkpoint_activations=False, + is_infer=False, + sequence_output=None, + parallel_output=True, + pruning_threshold=None): + if sequence_output is None: + sequence_output, pooled_output = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations, + pruning_threshold=pruning_threshold) + prediction_scores, seq_relationship_score = self.cls( + sequence_output, pooled_output) + else: + prediction_scores = None + sequence_output = sequence_output.to( + dtype=next(self.decoder.parameters()).dtype) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + decode_output = self.decoder( + self.bert.embeddings, + sequence_output, + decode_input_ids, + position_ids, + attention_mask, + decode_attention_mask, + checkpoint_activations=checkpoint_activations, + is_infer=is_infer, + pruning_threshold=pruning_threshold) + + transformer_output_parallel = mpu.copy_to_model_parallel_region( + decode_output) + + logits_parallel = F.linear(transformer_output_parallel, + self.bert.embeddings.word_embeddings.weight) + + if parallel_output: + return prediction_scores, logits_parallel + if is_infer: + return prediction_scores, mpu.gather_from_model_parallel_region( + logits_parallel), sequence_output + return prediction_scores, mpu.gather_from_model_parallel_region( + logits_parallel) + + +class PlugModel(torch.nn.Module): + + def __init__(self, config): + super(PlugModel, self).__init__() + self.config = config + 
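+        # PLUG follows the PALM architecture: a BERT-style encoder plus an
+        # autoregressive decoder that reuses the encoder's word embeddings.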
self.model = PalmForPreTraining(self.config) + + def forward(self, + input_tokens, + token_type_ids=None, + attention_mask=None, + target_tokens=None, + position_ids=None, + decode_attention_mask=None, + checkpoint_activations=False, + is_infer=False, + sequence_output=None, + parallel_output=True): + return self.model( + input_tokens, + token_type_ids, + attention_mask, + target_tokens, + position_ids, + decode_attention_mask, + checkpoint_activations=checkpoint_activations, + is_infer=is_infer, + sequence_output=sequence_output, + parallel_output=parallel_output) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.model.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + return self.model.load_state_dict(state_dict, strict=strict) diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index d4f9c6bf..5369220f 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -1,7 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os import os.path as osp from abc import ABC, abstractmethod +from functools import partial +from multiprocessing import Pool from threading import Lock from typing import Any, Dict, Generator, List, Mapping, Union @@ -15,8 +18,10 @@ from modelscope.utils.config import Config from modelscope.utils.constant import Frameworks, ModelFile from modelscope.utils.device import (create_device, device_placement, verify_device) +from modelscope.utils.hub import read_config, snapshot_download from modelscope.utils.import_utils import is_tf_available, is_torch_available from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import _find_free_port, _is_free_port from .util import is_model, is_official_hub_path if is_torch_available(): @@ -302,3 +307,106 @@ class Pipeline(ABC): output should have the standard output name. """ raise NotImplementedError('postprocess') + + +class DistributedPipeline(Pipeline): + """This pipeline is used to load multi gpu models. + + What will this class do: + 1. Read the global config from the configuration.json + 2. Set the multiprocessing method to spawn + 3. Open a multiprocessing pool of the world_size to instantiate model pieces. + 4. Set the master port and ip + 5. Call _instantiate_one to instantiate one model piece + This method should be implemented by the derived class. + 6. After the forward method is called, do preprocess in main process + and call _forward_one to collect results, and do + post process in main process. + + NOTE: _instantiate_one and _forward_one are class methods, any derived class should implement them and + store the model handler in the class field. 
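+
+    A minimal sketch of a derived class (names such as build_model_piece are
+    illustrative placeholders, not part of the actual API):
+
+        class MyDistributedPipeline(DistributedPipeline):
+            model = None
+
+            @classmethod
+            def _instantiate_one(cls, rank, model_dir, **kwargs):
+                # build this rank's model piece and keep it in a class field
+                cls.model = build_model_piece(model_dir, rank, **kwargs)
+
+            @classmethod
+            def _forward_one(cls, inputs):
+                # inputs is {'inputs': ..., 'forward_params': ...}, see forward()
+                return cls.model.generate(inputs['inputs'],
+                                          **inputs['forward_params'])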
+ """ + + def __init__(self, + model: str = None, + preprocessor: Union[Preprocessor, List[Preprocessor]] = None, + auto_collate=True, + **kwargs): + self.preprocessor = preprocessor + self._model_prepare = False + self._model_prepare_lock = Lock() + self._auto_collate = auto_collate + + if os.path.exists(model): + self.model_dir = model + else: + self.model_dir = snapshot_download(model) + self.cfg = read_config(self.model_dir) + self.world_size = self.cfg.model.world_size + self.model_pool = None + self.device_name = 'cpu' + self.device = create_device(self.device_name) + self.has_multiple_models = False + self.framework = self.cfg.framework + if torch.multiprocessing.get_start_method(allow_none=True) is None: + torch.multiprocessing.set_start_method('spawn') + + ranks = list(range(self.world_size)) + self.model_pool = Pool(self.world_size) + master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs[ + 'master_ip'] + master_port = '29500' if 'master_port' not in kwargs else kwargs[ + 'master_port'] + if not _is_free_port(int(master_port)): + master_port = str(_find_free_port()) + self.model_pool.map( + partial( + self.__class__._instantiate_one, + model_dir=self.model_dir, + master_ip=master_ip, + master_port=master_port, + **self.cfg.model, + **kwargs), ranks) + + def __del__(self): + if hasattr(self, 'model_pool') and self.model_pool is not None: + self.model_pool.terminate() + + def __getstate__(self): + self_dict = self.__dict__.copy() + del self_dict['model_pool'] + del self_dict['preprocessor'] + del self_dict['_model_prepare_lock'] + return self_dict + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + """Instantiate one model piece. + + @param rank: The model rank. + @param model_dir: The model_dir in the node. + @param kwargs: Any extra args. + @return: None. The model handler should be kept in the class field. + """ + pass + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + inputs = { + 'inputs': inputs, + 'forward_params': forward_params, + } + res = self.model_pool.map(self.__class__._forward_one, + [inputs] * self.world_size) + return res[0] + + @classmethod + def _forward_one(cls, inputs): + """Forward the inputs to one model piece. + + Use the model handler kept in the class field to forward. + + @param inputs: The inputs after the preprocessing. + @return: The forward results. + """ + pass diff --git a/modelscope/pipelines/nlp/distributed_plug_pipeline.py b/modelscope/pipelines/nlp/distributed_plug_pipeline.py new file mode 100644 index 00000000..202e6213 --- /dev/null +++ b/modelscope/pipelines/nlp/distributed_plug_pipeline.py @@ -0,0 +1,107 @@ +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.nlp.plug import DistributedPlug +from modelscope.pipelines.base import DistributedPipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TextGenerationPreprocessor +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.plug_generation) +class DistributedPlugPipeline(DistributedPipeline): + """This class is used to instantiate the plug model. + """ + + model = None + + def __init__(self, + model, + preprocessor=None, + first_sequence='sentence', + **kwargs): + """Create a plug pipeline instance. + + @param model: The model_id of plug(damo/nlp_plug_text-generation_27B). 
+ The default path to damo/nlp_plug_text-generation_27B can be obtained by function + get_cache_dir("damo/nlp_plug_text-generation_27B"), the model should be downloaded to + this path before calling this class by model_id. + The model can be downloaded from the link on + https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. + After downloading, you should have a plug model structure like this: + /your/path/to/damo/nlp_plug_text-generation_27B + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + @param preprocessor: The optional preprocessor, if not passed in, a TextGenerationPreprocessor will + be used as default. + @param first_sequence: The first_sequence key name if the input format is a dict. + @param kwargs: + sequence_length: The input sequence_length. + """ + if preprocessor is None: + preprocessor = TextGenerationPreprocessor( + model, + first_sequence=first_sequence, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model, preprocessor=preprocessor, **kwargs) + assert hasattr(preprocessor, 'tokenizer') + self.cls_token_id = preprocessor.tokenizer.cls_token_id + + @classmethod + def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + return cls.model.generate(inputs['inputs'], + **inputs['forward_params']) + + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + batch_size = inputs['input_ids'].shape[0] + dec_input_ids = torch.full([batch_size, 1], + self.cls_token_id, + dtype=torch.long) + inputs['dec_input_ids'] = dec_input_ids + res = super().forward(inputs, **forward_params) + return res + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + cls.model = DistributedPlug(model_dir, rank, **kwargs) + cls.model.eval() + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + from modelscope.outputs import OutputKeys + generate_context = inputs['generate_context'] + generate_context = ''.join( + self.preprocessor.tokenizer.convert_ids_to_tokens( + generate_context)).replace('[UNK]', '“').replace('##', '') + return {OutputKeys.TEXT: generate_context} diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 825611d6..cfb8c9e8 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -164,7 +164,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor): """ model_type = get_model_type(model_dir) - if model_type in (Models.structbert, Models.gpt3, Models.palm): + if model_type in (Models.structbert, Models.gpt3, Models.palm, + Models.plug): from modelscope.models.nlp.structbert import SbertTokenizer return SbertTokenizer.from_pretrained(model_dir, use_fast=False) elif model_type == Models.veco: diff --git a/modelscope/trainers/trainer.py 
b/modelscope/trainers/trainer.py index 614b728a..d011dd4a 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -39,7 +39,8 @@ from modelscope.utils.device import create_device, verify_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.logger import get_logger from modelscope.utils.registry import build_from_cfg -from modelscope.utils.torch_utils import get_dist_info, init_dist +from modelscope.utils.torch_utils import (get_dist_info, init_dist, + set_random_seed) from .base import BaseTrainer from .builder import TRAINERS from .default_config import DEFAULT_CONFIG @@ -922,6 +923,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed): # The seed of each worker equals to # num_worker * rank + worker_id + user_seed worker_seed = num_workers * rank + worker_id + seed - np.random.seed(worker_seed) - random.seed(worker_seed) - torch.manual_seed(worker_seed) + set_random_seed(worker_seed) diff --git a/modelscope/utils/nlp/distributed.py b/modelscope/utils/nlp/distributed.py new file mode 100755 index 00000000..2b590a10 --- /dev/null +++ b/modelscope/utils/nlp/distributed.py @@ -0,0 +1,130 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.distributed as dist +from megatron import mpu +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable +from torch.nn.modules import Module + +from modelscope.utils.torch_utils import init_dist + + +def initialize_distributed(rank, mpu, world_size, model_parallel_size, + master_ip, master_port): + """Initialize torch.distributed.""" + # Manually set the device ids. + device = rank % torch.cuda.device_count() + torch.cuda.set_device(device) + # Call the init process + init_method = 'tcp://' + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend='nccl', world_size=8, rank=rank, init_method=init_method) + # Set the model-parallel communicators. 
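+    # e.g. the released 27B PLUG checkpoint is split into 8 model-parallel
+    # shards (mp_rank_00..07), so world_size and model_parallel_size are both
+    # 8 and each rank loads exactly one shard.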
+ mpu.initialize_model_parallel(model_parallel_size) + + +def normal_init_method(mean, std): + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + + +def scaled_init_method(mean, std, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = std / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + + +class DistributedDataParallel(Module): + + def __init__(self, module): + super(DistributedDataParallel, self).__init__() + self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + self.module = module + self.data_parallel_group = mpu.get_data_parallel_group() + src_rank = mpu.get_model_parallel_rank() + for p in self.module.parameters(): + if torch.is_tensor(p): + dist.broadcast(p, src_rank, group=self.data_parallel_group) + + def allreduce_params(reduce_after=True, + no_scale=False, + fp32_allreduce=False): + if (self.needs_reduction): + self.needs_reduction = False + buckets = {} + for name, param in self.module.named_parameters(): + if param.requires_grad and param.grad is not None: + tp = (param.data.type()) + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + if self.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print( + 'WARNING: gloo dist backend for half parameters may be extremely slow.', + 'It is recommended to use the NCCL backend in this case.' + ) + self.warn_on_half = False + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + if fp32_allreduce: + coalesced = coalesced.float() + if not no_scale and not reduce_after: + coalesced /= dist.get_world_size( + group=self.data_parallel_group) + dist.all_reduce(coalesced, group=self.data_parallel_group) + torch.cuda.synchronize() + if not no_scale and reduce_after: + coalesced /= dist.get_world_size( + group=self.data_parallel_group) + for buf, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + self.hook_handles = [] + self.hooks = [] + for param in list(self.module.parameters()): + + def allreduce_hook(*unused): + Variable._execution_engine.queue_callback(allreduce_params) + + self.allreduce_params = allreduce_params + + def forward(self, *inputs, **kwargs): + self.needs_reduction = True + return self.module(*inputs, **kwargs) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + sd = self.module.state_dict(destination, prefix, keep_vars) + + return sd + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) diff --git a/modelscope/utils/nlp/load_checkpoint.py b/modelscope/utils/nlp/load_checkpoint.py new file mode 100755 index 00000000..6534e18d --- /dev/null +++ b/modelscope/utils/nlp/load_checkpoint.py @@ -0,0 +1,117 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + + +def load_checkpoint(model, + load_dir, + tag, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True): + r"""Load training checkpoint + + Arguments: + load_dir: Required. Directory to load the checkpoint from + tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step. + load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and + checkpoint match. + load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. + Ex. ADAM's momentum and variance + load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. + Return: + load_path: Path of the loaded checkpoint. None if loading the checkpoint failed + client_state: State dictionary used for loading required training states in the client code. + """ + + load_path, client_states = _load_checkpoint( + model, + load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states) + + if load_optimizer_states: + if model.zero_optimization() and load_path is not None: + model._load_zero_checkpoint( + load_dir, tag, load_optimizer_states=load_optimizer_states) + + return load_path, client_states + + +def _get_ckpt_name(mpu, checkpoints_path, tag): + mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() + ckpt_name = os.path.join( + checkpoints_path, str(tag), + 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt') + return ckpt_name + + +def pre_load(mpu, load_dir, tag=''): + load_path = _get_ckpt_name(mpu, load_dir, tag) + checkpoint = torch.load( + load_path, map_location=lambda storage, loc: storage) + return checkpoint['module'] + + +def _load_checkpoint(model, + load_dir, + tag, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True): + + load_path = model._get_ckpt_name(load_dir, tag) + + if not os.path.exists(load_path): + return None, None + + checkpoint = torch.load( + load_path, map_location=lambda storage, loc: storage) + + model.load_module_state_dict( + state_dict=checkpoint['module'], strict=load_module_strict) + if not model.zero_optimization() and load_optimizer_states: + if model.fp16_enabled(): + model.optimizer.load_state_dict( + checkpoint['optimizer'], + load_optimizer_states=load_optimizer_states) + elif load_optimizer_states: + model.optimizer.load_state_dict(checkpoint['optimizer']) + + if load_lr_scheduler_states and model.lr_scheduler is not None: + model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + model.csr_tensor_module_names = checkpoint['csr_tensor_module_names'] + model.global_steps = checkpoint['global_steps'] + model.global_samples = checkpoint.get( + 'global_samples', model.global_steps * model.train_batch_size()) + model.skipped_steps = checkpoint['skipped_steps'] + model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] + model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] + deepspeed_states = [ + 'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names', + 'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size' + ] + client_state = { + key: value + for key, value in checkpoint.items() if key not in deepspeed_states + } + + return load_path, client_state diff --git a/modelscope/utils/torch_utils.py 
b/modelscope/utils/torch_utils.py index 45e33c3e..eaa285a2 100644 --- a/modelscope/utils/torch_utils.py +++ b/modelscope/utils/torch_utils.py @@ -3,16 +3,16 @@ import functools import os import pickle +import random import socket import subprocess import tempfile from typing import Callable, List, Optional, Tuple +import numpy as np import torch import torch.multiprocessing as mp from torch import distributed as dist -from torch._utils import (_flatten_dense_tensors, _take_tensors, - _unflatten_dense_tensors) def _find_free_port() -> str: @@ -49,7 +49,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: def _init_dist_pytorch(backend: str, **kwargs) -> None: # rank = int(os.environ['RANK']) local_rank = int(os.environ['LOCAL_RANK']) - torch.cuda.set_device(local_rank) dist.init_process_group(backend=backend, **kwargs) @@ -180,3 +179,19 @@ def broadcast(inputs, src): dist.broadcast(inputs_tensor, src) return pickle.loads(inputs_tensor.cpu().numpy().tobytes()) + + +def set_random_seed(seed): + if seed is not None and seed >= 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + else: + raise ValueError( + f'Random seed should be positive, current seed is {seed}') + + +def set_random_seed_mpu(seed): + from megatron import mpu + set_random_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) diff --git a/requirements/nlp.txt b/requirements/nlp.txt index ada4fc50..cf0468bb 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -1,6 +1,8 @@ +deepspeed en_core_web_sm>=2.3.5 fairseq>=0.10.2 jieba>=0.42.1 +megatron_util pai-easynlp # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated diff --git a/tests/pipelines/test_plug_text_generation.py b/tests/pipelines/test_plug_text_generation.py new file mode 100644 index 00000000..90b48efa --- /dev/null +++ b/tests/pipelines/test_plug_text_generation.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks + + +class TextPlugGenerationTest(unittest.TestCase): + + def setUp(self) -> None: + # please make sure this local path exists. + self.model_id = 'damo/nlp_plug_text-generation_27B' + self.model_dir = snapshot_download(self.model_id) + self.plug_input = '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。"' + + @unittest.skip('distributed plug, skipped') + def test_plug(self): + """ The model can be downloaded from the link on + https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. 
+ After downloading, you should have a plug model structure like this: + nlp_plug_text-generation_27B + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + """ + # download model binaries to /model + pipe = pipeline(Tasks.text_generation, model=self.model_id) + print( + f'input: {self.plug_input}\noutput: {pipe(self.plug_input, out_length=256)}' + ) + + +if __name__ == '__main__': + unittest.main()
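For reference, a minimal usage sketch assembled from the pieces above (the
master port value, the truncated prompt and out_length are illustrative;
DistributedPipeline falls back to a free port automatically if the requested
one is occupied, and extra keyword arguments are expected to reach
DistributedPipeline.__init__ via the pipeline builder):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(
        Tasks.text_generation,
        model='damo/nlp_plug_text-generation_27B',
        master_ip='127.0.0.1',
        master_port='29501')
    print(pipe('段誉轻挥折扇,摇了摇头', out_length=256))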