From 941dbe75cf8c14d27c0877d57c75eaf15f7e7af0 Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Mon, 5 Dec 2022 10:01:32 +0800
Subject: [PATCH] [to #42322933] Add GPT-3 tensor parallel finetuning

Add GPT-3 tensor parallel finetuning and adjust the distributed code so
that tensor parallelism and data parallelism are compatible.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10949507
---
 modelscope/metainfo.py                        |   1 +
 modelscope/metrics/text_generation_metric.py  |   2 +-
 modelscope/models/nlp/__init__.py             |   4 +-
 modelscope/models/nlp/gpt3/__init__.py        |   2 +
 .../models/nlp/gpt3/distributed_gpt3.py       | 147 +++++++++++++++---
 modelscope/models/nlp/gpt3/text_generation.py |  22 ++-
 modelscope/models/nlp/gpt3/tokenizer.py       |  10 +-
 .../models/nlp/plug/distributed_plug.py       |   3 +-
 .../nlp/distributed_gpt3_pipeline.py          |   7 +-
 .../pipelines/nlp/text_generation_pipeline.py |   6 +-
 .../nlp/text_generation_preprocessor.py       |  33 +++-
 .../trainers/hooks/logger/text_logger_hook.py |   2 +-
 modelscope/trainers/nlp/gpt3_trainer.py       |  61 ++++++++
 modelscope/trainers/trainer.py                |   2 +-
 modelscope/trainers/utils/inference.py        |   3 +-
 modelscope/utils/nlp/distributed.py           |   3 +-
 modelscope/utils/nlp/load_checkpoint.py       |   7 +-
 modelscope/utils/torch_utils.py               |  23 +--
 tests/trainers/test_finetune_gpt3.py          | 129 +++++++++++++++
 .../trainers/test_finetune_text_generation.py |   3 +-
 20 files changed, 403 insertions(+), 67 deletions(-)
 create mode 100644 modelscope/trainers/nlp/gpt3_trainer.py
 create mode 100644 tests/trainers/test_finetune_gpt3.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 12274fb9..2a05035a 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -349,6 +349,7 @@ class Trainers(object):
     nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
     text_generation_trainer = 'text-generation-trainer'
     nlp_plug_trainer = 'nlp-plug-trainer'
+    gpt3_trainer = 'nlp-gpt3-trainer'
 
     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py
index 3d6e6964..adad871e 100644
--- a/modelscope/metrics/text_generation_metric.py
+++ b/modelscope/metrics/text_generation_metric.py
@@ -44,7 +44,7 @@ class TextGenerationMetric(Metric):
         def remove_useless(string: str) -> str:
             return string.replace(' ', '').replace('.', '')
 
-        return remove_useless(pred) and remove_useless(tgt)
+        return len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0
 
     def evaluate(self):
         assert self.preds, 'preds in TextGenerationMetric must not be empty!'
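A note on the metric fix above: in Python, `s1 and s2` on two strings returns one of the strings, not a bool, so the old expression leaked an empty string where a boolean was expected. A minimal standalone sketch (function names illustrative, not part of the patch):

    # `and` on two strings yields a string, so the old form returned ''
    # (falsy, but typed str) whenever the prediction was empty; the new
    # length check returns a real bool.
    def old_check(pred: str, tgt: str):
        return pred.replace(' ', '') and tgt.replace(' ', '')

    def new_check(pred: str, tgt: str) -> bool:
        return len(pred.replace(' ', '')) != 0 and len(tgt.replace(' ', '')) != 0

    assert old_check('', 'abc') == ''     # falsy string, not False
    assert new_check('', 'abc') is False  # explicit bool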
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index e26bd74e..44aa813a 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
     from .csanmt import CsanmtForTranslation
     from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model
     from .gpt_neo import GPTNeoModel
-    from .gpt3 import GPT3ForTextGeneration
+    from .gpt3 import GPT3ForTextGeneration, DistributedGPT3
     from .heads import SequenceClassificationHead
    from .palm_v2 import PalmForTextGeneration
     from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig
@@ -59,7 +59,7 @@ else:
         'bart': ['BartForTextErrorCorrection'],
         'csanmt': ['CsanmtForTranslation'],
         'heads': ['SequenceClassificationHead'],
-        'gpt3': ['GPT3ForTextGeneration'],
+        'gpt3': ['GPT3ForTextGeneration', 'DistributedGPT3'],
         'structbert': [
             'SbertForFaqQuestionAnswering',
             'SbertForMaskedLM',
diff --git a/modelscope/models/nlp/gpt3/__init__.py b/modelscope/models/nlp/gpt3/__init__.py
index 051cc8f2..347e53bf 100644
--- a/modelscope/models/nlp/gpt3/__init__.py
+++ b/modelscope/models/nlp/gpt3/__init__.py
@@ -8,12 +8,14 @@ if TYPE_CHECKING:
     from .backbone import GPT3Model
     from .text_generation import GPT3ForTextGeneration
     from .tokenizer import JiebaBPETokenizer
+    from .distributed_gpt3 import DistributedGPT3
 else:
     _import_structure = {
         'configuration': ['GPT3Config'],
         'backbone': ['GPT3Model'],
         'text_generation': ['GPT3ForTextGeneration'],
         'tokenizer': ['JiebaBPETokenizer'],
+        'distributed_gpt3': ['DistributedGPT3'],
     }
 
     import sys
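Both __init__.py changes follow the files' existing lazy-import pattern: DistributedGPT3 is visible to type checkers, while the megatron-dependent module is only imported on first access. A rough standalone sketch of the same idea using PEP 562 module __getattr__ (the real files delegate to modelscope's LazyImportModule instead):

    # Illustrative sketch only; the actual __init__.py files use
    # modelscope's LazyImportModule rather than this __getattr__ hook.
    import importlib

    _import_structure = {'distributed_gpt3': ['DistributedGPT3']}

    def __getattr__(name):
        for module_name, symbols in _import_structure.items():
            if name in symbols:
                # heavy deps (e.g. megatron) load only when first accessed
                module = importlib.import_module(f'.{module_name}', __package__)
                return getattr(module, name)
        raise AttributeError(name)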
diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py
index a0091259..424e43b4 100644
--- a/modelscope/models/nlp/gpt3/distributed_gpt3.py
+++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py
@@ -13,7 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import io
 import math
+import os
+from os import path as osp
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from megatron import mpu
@@ -25,8 +29,14 @@
 from torch import nn
 from torch.nn import functional as F
 from transformers.modeling_utils import PreTrainedModel
+from modelscope.fileio import File
+from modelscope.metainfo import Models
 from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
 from modelscope.models.nlp.gpt3 import GPT3Config
+from modelscope.outputs import TextGenerationModelOutput, TokenGeneratorOutput
+from modelscope.utils.checkpoint import weights_to_cpu
+from modelscope.utils.constant import Tasks
 from modelscope.utils.nlp.distributed import initialize_distributed
 from modelscope.utils.nlp.load_checkpoint import pre_load
 from modelscope.utils.torch_utils import set_random_seed_mpu
@@ -435,7 +445,7 @@ class nullcontext:
 
 def bias_dropout_add(x, bias, residual, prob, training):
     # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
-    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    out = F.dropout(x + bias, p=prob, training=training)
     out = residual + out
     return out
 
@@ -747,11 +757,9 @@ class GPT3Model(PreTrainedModel):
 
     config_class = GPT3Config
 
-    def __init__(self, config, parallel_output=False):
+    def __init__(self, config):
         super().__init__(config)
 
-        self.parallel_output = parallel_output
-
         self.language_model = GPT3TransformerLanguageModel(
             config, init_method_normal(config.init_method_std),
             scaled_init_method_normal(config.init_method_std,
@@ -764,9 +772,7 @@ class GPT3Model(PreTrainedModel):
     def build_attention_mask_and_position_ids(tokens):
         seq_length = tokens.size(1)
         attention_mask = torch.tril(
-            torch.ones((1, 1, seq_length, seq_length),
-                       dtype=torch.long,
-                       device=tokens.device))
+            torch.ones((1, 1, seq_length, seq_length), device=tokens.device))
         attention_mask = (attention_mask < 0.5)
 
         position_ids = torch.arange(
@@ -780,6 +786,7 @@ class GPT3Model(PreTrainedModel):
                 attention_mask=None,
                 position_ids=None,
                 inference_params=None,
+                labels=None,
                 **kwargs):
         if attention_mask is None and position_ids is None:
             attention_mask, position_ids = \
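For reference, build_attention_mask_and_position_ids above produces a boolean causal mask (True marks positions a token must not attend to) plus row-wise positions; a small standalone sanity sketch for a length-4 sequence:

    # Standalone check of the mask/position construction shown above.
    import torch

    tokens = torch.zeros(1, 4, dtype=torch.long)
    seq_length = tokens.size(1)
    attention_mask = torch.tril(
        torch.ones((1, 1, seq_length, seq_length), device=tokens.device))
    attention_mask = attention_mask < 0.5        # True above the diagonal
    position_ids = torch.arange(
        seq_length, dtype=torch.long, device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    print(attention_mask[0, 0])   # upper-triangular block of True
    print(position_ids)           # tensor([[0, 1, 2, 3]])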
@@ -797,9 +804,18 @@
 
         # Gather if needed.
         output = logits_parallel
-        if not self.parallel_output:
+
+        if labels is None:
             output = mpu.gather_from_model_parallel_region(logits_parallel)
-        return output.transpose(0, 1).contiguous()
+            # [s b h] => [b s h]
+            return output.transpose(0, 1).contiguous()
+        else:
+            # [b s] => [s b]
+            labels = labels.transpose(0, 1).contiguous()
+            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+            # [s b] => [b s]
+            loss = loss.transpose(0, 1).contiguous()
+            return loss
 
 
 def modify_logits_for_top_k_filtering(logits, top_k):
@@ -911,6 +927,51 @@ class InferenceParams:
             new_inference_key_memory, new_inference_value_memory)
 
 
+def split_into_partitions(tensor, num_partitions, partition_dim, stride):
+    per_partition_size = mpu.utils.divide(
+        tensor.size(partition_dim), num_partitions)
+    per_partition_per_stride_size = mpu.utils.divide(per_partition_size,
+                                                     stride)
+    partitions_list = torch.split(
+        tensor, per_partition_per_stride_size, dim=partition_dim)
+    partitions = []
+    for i in range(num_partitions):
+        partition = torch.cat(
+            partitions_list[i::num_partitions], dim=partition_dim)
+        partitions.append(partition)
+    return partitions
+
+
+def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
+                     partitions: int) -> Dict[str, torch.Tensor]:
+    if partitions == 1:
+        return state_dict
+    rank: int = mpu.get_model_parallel_rank()
+    for name, parameters in model.named_parameters():
+        if parameters.shape == state_dict[name].shape:
+            continue
+        dim = max(parameters.partition_dim, 0)
+        stride = parameters.partition_stride
+        state_dict[name] = split_into_partitions(state_dict[name], partitions,
+                                                 dim, stride)[rank]
+    return state_dict
+
+
+def save_checkpoint(model: torch.nn.Module, filename: str) -> None:
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model = model.module
+
+    checkpoint = {'module': weights_to_cpu(model.state_dict())}
+
+    mp_rank = mpu.get_model_parallel_rank()
+    filename = osp.join(
+        osp.dirname(filename), 'model',
+        'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
+
+    with io.BytesIO() as f:
+        torch.save(checkpoint, f)
+        File.write(f.getvalue(), filename)
+
+
 class DistributedGPT3(TorchModel):
 
     def __init__(self,
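The resharding helpers above let a checkpoint saved with one model-parallel size be loaded at another. A self-contained sketch of what split_into_partitions computes, simplified with plain integer division in place of mpu.utils.divide:

    import torch

    def split_into_partitions(tensor, num_partitions, partition_dim, stride):
        per_partition_size = tensor.size(partition_dim) // num_partitions
        per_stride_size = per_partition_size // stride
        pieces = torch.split(tensor, per_stride_size, dim=partition_dim)
        # interleave the stride-sized pieces across the partitions
        return [
            torch.cat(pieces[i::num_partitions], dim=partition_dim)
            for i in range(num_partitions)
        ]

    weight = torch.arange(8)               # a tiny 1-D parallel weight
    p0, p1 = split_into_partitions(weight, 2, 0, 1)
    print(p0.tolist())                     # [0, 1, 2, 3] -> rank 0's shard
    print(p1.tolist())                     # [4, 5, 6, 7] -> rank 1's shard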
@@ -942,33 +1003,63 @@ class DistributedGPT3(TorchModel):
             model = Float16Module(model, self.config)
 
         self.dist_model = model
-        load_model = pre_load(mpu, model_dir, tag=path_load_tag)
+
+        tensor_ws = mpu.get_model_parallel_world_size()
+        ckpt_ws = kwargs.pop('checkpoint_model_parallel_size', tensor_ws)
+        ckpt_rank = mpu.get_model_parallel_rank() * ckpt_ws // tensor_ws
+        load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag)
+        load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws)
+
         self.dist_model.load_state_dict(load_model)
 
         self.inference_params = None
 
-    def forward_step(self, tokens, attention_mask, position_ids):
-        logits = self.dist_model(
+    def train(self, mode: bool = True):
+        if mode:
+            self.inference_params = None
+        return super().train(mode)
+
+    def forward(self,
+                tokens,
+                attention_mask=None,
+                position_ids=None,
+                labels=None,
+                prompt_length=None):
+        outputs = self.dist_model(
             tokens,
             attention_mask,
             position_ids,
-            inference_params=self.inference_params)
-        self.inference_params.sequence_len_offset += tokens.size(1)
-        return logits
+            inference_params=self.inference_params,
+            labels=labels)
+        if labels is None:
+            self.inference_params.sequence_len_offset += tokens.size(1)
+            return TextGenerationModelOutput(logits=outputs)
+        else:
+            loss_mask = torch.ones(
+                tokens.size(), dtype=torch.float,
+                device=tokens.device)
+
+            losses = outputs.float()
+            loss_mask = loss_mask.view(-1).float()
+            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+            return TextGenerationModelOutput(loss=loss)
 
     def generate(self,
                  tokens,
                  temperature=1.0,
                  use_eod_token_for_early_termination=True,
                  stop_on_double_eol=False,
-                 stop_on_eol=False):
-        lengths = torch.tensor([tokens.size(1)], device=tokens.device)
+                 stop_on_eol=False,
+                 **kwargs):
+        batch_size = tokens.size(0)
+        lengths = kwargs.pop(
+            'prompt_length',
+            torch.tensor([tokens.size(1)], device=tokens.device))
         pads = torch.ones(
-            1, self.config.tokens_to_generate,
+            batch_size, self.config.tokens_to_generate,
             device=tokens.device).long() * self.config.eod_id
         tokens = torch.cat((tokens, pads), dim=-1)
 
-        batch_size = tokens.size(0)
         min_prompt_length = lengths.min().item()
         max_sequence_length = tokens.size(1)
         max_sequence_length = min(max_sequence_length,
@@ -1009,8 +1100,8 @@ class DistributedGPT3(TorchModel):
                 ..., prev_context_length:context_length, :context_length]
 
             # logits will be meaningful only in the last pipeline stage.
-            logits = self.forward_step(tokens2use, attention_mask2use,
-                                       positions2use)
+            logits = self(tokens2use, attention_mask2use,
+                          positions2use).logits
 
             # Sample.
             last_token_logits = logits[:, -1, :]
@@ -1054,4 +1145,16 @@ class DistributedGPT3(TorchModel):
                 break
 
         tokens = tokens[:, :(context_length + 1)]
-        return tokens
+        return TokenGeneratorOutput(sequences=tokens)
+
+    def state_dict(self):
+        return self.dist_model.state_dict()
+
+    def save_pretrained(self,
+                        target_folder: Union[str, os.PathLike],
+                        save_checkpoint_names: Union[str, List[str]] = None,
+                        save_function: Callable = save_checkpoint,
+                        config: Optional[dict] = None,
+                        **kwargs):
+        return super().save_pretrained(target_folder, save_checkpoint_names,
+                                       save_function, config, **kwargs)
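Since forward builds an all-ones loss_mask, the masked sum above currently reduces to a plain mean over every token position; a quick numeric check:

    import torch

    losses = torch.tensor([[1., 2.], [3., 4.]])    # per-token losses, [batch, seq]
    loss_mask = torch.ones(losses.size(), dtype=torch.float)
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    assert loss.item() == losses.mean().item()     # 2.5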
diff --git a/modelscope/models/nlp/gpt3/text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py
index b8b705a5..74335de6 100644
--- a/modelscope/models/nlp/gpt3/text_generation.py
+++ b/modelscope/models/nlp/gpt3/text_generation.py
@@ -1,10 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 from typing import Dict
 
+from transformers import BertTokenizer
+
 from modelscope.metainfo import Models
 from modelscope.models.base import Tensor, TorchModel
 from modelscope.models.builder import MODELS
-from modelscope.outputs import OutputKeys
+from modelscope.models.nlp.gpt3 import GPT3Model
 from modelscope.utils.constant import Tasks
 
 __all__ = ['GPT3ForTextGeneration']
@@ -21,11 +24,15 @@ class GPT3ForTextGeneration(TorchModel):
         """
         super().__init__(model_dir, *args, **kwargs)
 
-        from modelscope.models.nlp.gpt3 import GPT3Model
-        from transformers import BertTokenizer
-
-        self.model = GPT3Model.from_pretrained(model_dir)
-        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+        # Temporarily compatible with both DistributedGPT3 and GPT3Model;
+        # the base/large models built on GPT3Model will be replaced in the
+        # future, and GPT3Model will then be deprecated.
+        if 'model_parallel_size' in kwargs:
+            from modelscope.models.nlp import DistributedGPT3
+            self.model = DistributedGPT3(model_dir, **kwargs)
+        else:
+            self.model = GPT3Model.from_pretrained(model_dir)
+            self.tokenizer = BertTokenizer.from_pretrained(model_dir)
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -43,6 +50,9 @@ class GPT3ForTextGeneration(TorchModel):
         return self.model(**input)
 
     def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        if not isinstance(self.model, GPT3Model):
+            return self.model.generate(**input)
+
         assert 'input_ids' in input, "generate function must accept 'input_ids' key"
         input_ids = input['input_ids']
         if 'attention_mask' in input:
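A hypothetical construction sketch of the dispatch above (the paths are illustrative; in practice the kwarg comes from the model's configuration):

    # 'model_parallel_size' in kwargs selects the megatron-backed
    # DistributedGPT3; otherwise the original GPT3Model/BertTokenizer
    # pair is used. Paths here are placeholders.
    from modelscope.models.nlp import GPT3ForTextGeneration

    model = GPT3ForTextGeneration('/path/to/model_dir')        # GPT3Model backend
    model_mp = GPT3ForTextGeneration(
        '/path/to/model_dir', model_parallel_size=2)           # DistributedGPT3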
diff --git a/modelscope/models/nlp/gpt3/tokenizer.py b/modelscope/models/nlp/gpt3/tokenizer.py
index 5780ddbd..ba29891e 100644
--- a/modelscope/models/nlp/gpt3/tokenizer.py
+++ b/modelscope/models/nlp/gpt3/tokenizer.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+
 from tokenizers import Tokenizer
 
 
@@ -25,9 +27,11 @@ class JiebaBPETokenizer:
         self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
         try:
             import jieba
+            import logging
+            jieba.setLogLevel(logging.INFO)
         except ImportError:
             raise ImportError(
-                'You need to install rjieba to use JiebaTokenizer. '
-                'See https://pypi.org/project/rjieba/ for installation.')
+                'You need to install jieba to use JiebaTokenizer. '
+                'See https://pypi.org/project/jieba/ for installation.')
         self.jieba = jieba
         self.new_line = self.vocab['\n']
@@ -49,7 +53,7 @@ class JiebaBPETokenizer:
             inv_vocab[val] = key
         return inv_vocab
 
-    def tokenize(self, text, is_code=False):
+    def tokenize(self, text: str, is_code: bool = False) -> List[int]:
         """ """
         if not is_code:
@@ -61,7 +65,7 @@ class JiebaBPETokenizer:
             text, is_pretokenized=False, add_special_tokens=True).ids
 
     def detokenize(self, token_ids):
-        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
+        text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
         return text
 
     @property
diff --git a/modelscope/models/nlp/plug/distributed_plug.py b/modelscope/models/nlp/plug/distributed_plug.py
index e8c04de3..23b83078 100644
--- a/modelscope/models/nlp/plug/distributed_plug.py
+++ b/modelscope/models/nlp/plug/distributed_plug.py
@@ -110,7 +110,8 @@ class DistributedPlug(TorchModel):
             if 'LayerNorm' in name:
                 _module.float()
 
-        load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
+        load_model = pre_load(
+            mpu.get_model_parallel_rank(), self.model_dir, tag=path_load_tag)
         model_dict = model.module.model.state_dict()
         for key in load_model:
             if key not in model_dict.keys():
diff --git a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
index 216d5302..e098823b 100644
--- a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
+++ b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
@@ -5,7 +5,7 @@ from typing import Any, Dict
 import torch
 
 from modelscope.metainfo import Pipelines
-from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3
+from modelscope.models.nlp import DistributedGPT3
 from modelscope.pipelines.base import DistributedPipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import TextGenerationJiebaPreprocessor
@@ -30,7 +30,7 @@ class DistributedGPT3Pipeline(DistributedPipeline):
                 Extra kwargs passed into the preprocessor's constructor.
         """
         if preprocessor is None:
-            preprocessor = TextGenerationJiebaPreprocessor(model, **kwargs)
+            preprocessor = TextGenerationJiebaPreprocessor(model)
         super().__init__(model, preprocessor=preprocessor, **kwargs)
         assert hasattr(preprocessor, 'tokenizer')
 
@@ -58,5 +58,6 @@ class DistributedGPT3Pipeline(DistributedPipeline):
         from modelscope.outputs import OutputKeys
         return {
             OutputKeys.TEXT:
-            self.preprocessor.tokenizer.detokenize(inputs[0].tolist())
+            self.preprocessor.tokenizer.detokenize(
+                inputs.sequences[0].tolist())
         }
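The postprocess change above follows from generate() now returning a structured output instead of a bare tensor; a minimal sketch of the unpacking:

    import torch
    from modelscope.outputs import TokenGeneratorOutput

    outputs = TokenGeneratorOutput(sequences=torch.tensor([[8, 9, 10]]))
    token_ids = outputs.sequences[0].tolist()   # [8, 9, 10]
    # token_ids then goes through tokenizer.detokenize(), which, after the
    # tokenizer change above, skips special tokens such as <|endoftext|>.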
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index 566ca359..16e871ab 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -30,7 +30,6 @@ class TextGenerationPipeline(Pipeline):
                  device: str = 'gpu',
                  auto_collate=True,
                  first_sequence='sentence',
-                 sequence_length=128,
                  **kwargs):
         """Use `model` and `preprocessor` to create a generation pipeline for
         prediction.
@@ -63,10 +62,7 @@ class TextGenerationPipeline(Pipeline):
 
         if preprocessor is None:
             self.preprocessor = Preprocessor.from_pretrained(
-                self.model.model_dir,
-                first_sequence=first_sequence,
-                sequence_length=sequence_length,
-                **kwargs)
+                self.model.model_dir, first_sequence=first_sequence, **kwargs)
         self.model.eval()
         self.postprocessor = kwargs.pop('postprocessor', None)
         if self.postprocessor is None and hasattr(self.model, 'model_dir'):
diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py
index e0f8d943..71665fab 100644
--- a/modelscope/preprocessors/nlp/text_generation_preprocessor.py
+++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py
@@ -192,7 +192,9 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
                  model_dir: str,
                  mode: str = ModeKeys.INFERENCE,
                  src_txt='src_txt',
-                 tgt_txt=None):
+                 tgt_txt=None,
+                 sequence_length: int = 128,
+                 use_fast=None):
         from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
         super().__init__(mode, src_txt, tgt_txt)
         if self.tgt_txt is not None:
@@ -202,6 +204,7 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
             self.src_txt = src_txt
         self.tokenizer = JiebaBPETokenizer(
             osp.join(model_dir, 'tokenizer.json'))
+        self.max_length = sequence_length
 
     def decode(self, tokens, **kwargs):
         """Decode the tokens to real text.
@@ -214,6 +217,14 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
         """
         return self.tokenizer.detokenize(tokens)
 
+    def _truncate(self, array: np.ndarray) -> np.ndarray:
+        if len(array) < self.max_length:
+            return np.pad(
+                array, (0, self.max_length - len(array)),
+                constant_values=self.tokenizer.eod)
+        else:
+            return array[:self.max_length]
+
     def _tokenize_text(self, sequence1, sequence2=None, **kwargs):
         """Tokenize the text.
 
@@ -224,10 +235,22 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
         Returns:
             The encoded sequence.
         """
-        return {
-            'input_ids':
-            torch.tensor(self.tokenizer.tokenize(sequence1)).unsqueeze_(0)
-        }
+        if self.mode == ModeKeys.INFERENCE:
+            return {
+                'input_ids':
+                torch.tensor(self.tokenizer.tokenize(sequence1)).unsqueeze_(0)
+            }
+        else:
+            tokens = self.tokenizer.tokenize(sequence1)
+            prompt_length = min(len(tokens), self.max_length - 1)
+            if sequence2 is not None:
+                tokens += self.tokenizer.tokenize(sequence2)
+            tokens = self._truncate(np.array(tokens))
+            return {
+                'tokens': tokens[:-1],
+                'labels': tokens[1:],
+                'prompt_length': prompt_length,
+            }
 
 
 @PREPROCESSORS.register_module(
diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py
index 223867b2..eb22d03c 100644
--- a/modelscope/trainers/hooks/logger/text_logger_hook.py
+++ b/modelscope/trainers/hooks/logger/text_logger_hook.py
@@ -74,7 +74,7 @@ class TextLoggerHook(LoggerHook):
             self._dump_log(trainer.meta)
 
     def _get_max_memory(self, trainer):
-        device = getattr(trainer.model, 'output_device', None)
+        device = torch.cuda.current_device()
         mem = torch.cuda.max_memory_allocated(device=device)
         mem_mb = torch.tensor([mem / (1024 * 1024)],
                               dtype=torch.int,
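In training mode the Jieba preprocessor above emits next-token-prediction pairs: tokens and labels are the same sequence offset by one, with EOD used for padding. A worked toy example (token ids invented):

    import numpy as np

    eod = 0                                      # illustrative EOD id
    tokens = np.array([11, 12, 13, 14, eod])     # prompt+target, truncated/padded
    model_inputs = tokens[:-1]                   # [11, 12, 13, 14]
    labels = tokens[1:]                          # [12, 13, 14, 0]
    # position i of model_inputs is trained to predict labels[i]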
diff --git a/modelscope/trainers/nlp/gpt3_trainer.py b/modelscope/trainers/nlp/gpt3_trainer.py
new file mode 100644
index 00000000..51e7ba1e
--- /dev/null
+++ b/modelscope/trainers/nlp/gpt3_trainer.py
@@ -0,0 +1,61 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from collections.abc import Mapping
+from typing import List
+
+import torch
+from megatron import mpu
+
+from modelscope.metainfo import Trainers
+from modelscope.models import TorchModel
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.utils.config import Config
+from modelscope.utils.file_utils import func_receive_dict_inputs
+
+
+@TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
+class GPT3Trainer(NlpEpochBasedTrainer):
+
+    def rebuild_config(self, cfg: Config):
+        super().rebuild_config(cfg)
+        cfg.model.rank = int(os.environ.get('LOCAL_RANK', -1))
+        cfg.model.master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
+        cfg.model.master_port = os.environ.get('MASTER_PORT', '29500')
+        return cfg
+
+    def train_step(self, model: TorchModel, inputs: Mapping):
+        keys = list(inputs.keys())
+        datatype = torch.int64
+        inputs = mpu.broadcast_data(keys, inputs, datatype)
+        return super().train_step(model, inputs)
+
+    def _decode(self, tokens):
+        tokenizer = self.eval_preprocessor.tokenizer
+        return tokenizer.detokenize(tokens.tolist())
+
+    def evaluation_step(self, data):
+        model = self.model.module if self._dist else self.model
+        model.eval()
+
+        with torch.no_grad():
+            if isinstance(
+                    data,
+                    Mapping) and not func_receive_dict_inputs(model.generate):
+                result = model.generate(**data)
+            else:
+                result = model.generate(data)
+
+        prompt_length: List[int] = data['prompt_length']
+        result['preds'] = [
+            self._decode(seq[skip_len:])
+            for seq, skip_len in zip(result['sequences'], prompt_length)
+        ]
+        data['tgts'] = [
+            self._decode(seq[skip_len - 1:])
+            for seq, skip_len in zip(data['labels'], prompt_length)
+        ]
+        assert len(result['preds']) == len(data['tgts'])
+
+        return result
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index 649cb96a..fbfcf96c 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -236,7 +236,7 @@ class EpochBasedTrainer(BaseTrainer):
                 device_name: The final device name.
         """
         device_name = device if device is not None else 'gpu'
-        if self._dist:
+        if dist.is_initialized():
             local_rank = get_local_rank()
             device_name = f'cuda:{local_rank}'
 
diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py
index 4ea34d59..631d011e 100644
--- a/modelscope/trainers/utils/inference.py
+++ b/modelscope/trainers/utils/inference.py
@@ -137,7 +137,8 @@ def multi_gpu_test(trainer,
         else:
             batch_size = len(data)
             if i >= (data_len // world_size) - 1:
-                total_samples = torch.LongTensor([batch_size]).to(model.device)
+                total_samples = torch.LongTensor([batch_size
+                                                  ]).to(trainer.model.device)
                 dist.all_reduce(total_samples, op=dist.reduce_op.SUM)
                 total_samples = total_samples.item()
         else:
diff --git a/modelscope/utils/nlp/distributed.py b/modelscope/utils/nlp/distributed.py
index 53332c0f..3dcb5f71 100755
--- a/modelscope/utils/nlp/distributed.py
+++ b/modelscope/utils/nlp/distributed.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import math
+import os
 
 import torch
 import torch.distributed as dist
@@ -36,7 +37,7 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size,
     init_method += master_ip + ':' + master_port
     torch.distributed.init_process_group(
         backend='nccl',
-        world_size=world_size,
+        world_size=int(os.getenv('WORLD_SIZE', world_size)),
         rank=rank,
         init_method=init_method)
     # Set the model-parallel communicators.
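GPT3Trainer.train_step above relies on mpu.broadcast_data so that every tensor-parallel rank in a model-parallel group sees the identical int64 batch. An illustrative stand-in built directly on torch.distributed (not the Megatron implementation):

    import torch.distributed as dist

    def broadcast_batch(batch, src=0, group=None):
        # Assumes every tensor is already on the current CUDA device and
        # has the same shape/dtype on all ranks; keys are iterated in a
        # fixed order so all ranks issue matching collectives.
        for key in sorted(batch.keys()):
            dist.broadcast(batch[key], src=src, group=group)
        return batch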
diff --git a/modelscope/utils/nlp/load_checkpoint.py b/modelscope/utils/nlp/load_checkpoint.py
index 6534e18d..920097a0 100755
--- a/modelscope/utils/nlp/load_checkpoint.py
+++ b/modelscope/utils/nlp/load_checkpoint.py
@@ -55,16 +55,15 @@ def load_checkpoint(model,
     return load_path, client_states
 
 
-def _get_ckpt_name(mpu, checkpoints_path, tag):
-    mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
+def _get_ckpt_name(mp_rank, checkpoints_path, tag):
     ckpt_name = os.path.join(
         checkpoints_path, str(tag),
         'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
     return ckpt_name
 
 
-def pre_load(mpu, load_dir, tag=''):
-    load_path = _get_ckpt_name(mpu, load_dir, tag)
+def pre_load(mp_rank, load_dir, tag=''):
+    load_path = _get_ckpt_name(mp_rank, load_dir, tag)
     checkpoint = torch.load(
         load_path, map_location=lambda storage, loc: storage)
     return checkpoint['module']
diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py
index 74d9bb7b..e8c21d86 100644
--- a/modelscope/utils/torch_utils.py
+++ b/modelscope/utils/torch_utils.py
@@ -107,8 +107,14 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
 
 def get_dist_info() -> Tuple[int, int]:
     if dist.is_available() and dist.is_initialized():
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        try:
+            from megatron import mpu
+            assert mpu.model_parallel_is_initialized()
+            rank = mpu.get_data_parallel_rank()
+            world_size = mpu.get_data_parallel_world_size()
+        except (ImportError, AssertionError):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
     else:
         rank = 0
         world_size = 1
@@ -120,16 +126,14 @@ def get_local_rank():
 
 
 def is_master():
-    rank, _ = get_dist_info()
-    return rank == 0
+    return dist.get_rank() == 0 if dist.is_initialized() else True
 
 
 def master_only(func: Callable) -> Callable:
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        rank, _ = get_dist_info()
-        if rank == 0:
+        if is_master():
             return func(*args, **kwargs)
 
     return wrapper
@@ -138,12 +142,11 @@ def master_only(func: Callable) -> Callable:
 def make_tmp_dir():
     """Make sure each rank has the same temporary directory on the
     distributed mode.
     """
-    rank, world_size = get_dist_info()
-    if world_size <= 1:
+    if not dist.is_initialized():
         return tempfile.mkdtemp()
 
     tmpdir = None
-    if rank == 0:
+    if is_master():
         tmpdir = tempfile.mkdtemp()
 
     dist.barrier()
@@ -162,7 +165,7 @@ def broadcast(inputs, src):
     Returns:
         Each rank returns the same value as src.
     """
-    rank, _ = get_dist_info()
+    rank = dist.get_rank()
 
     shape_tensor = torch.tensor([0], device='cuda')
 
     if rank == src:
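The pre_load change and the new save_checkpoint agree on a one-file-per-model-parallel-rank checkpoint layout; a quick sketch of the resulting paths (directory is illustrative):

    import os

    def ckpt_name(mp_rank, checkpoints_path, tag=''):
        return os.path.join(checkpoints_path, str(tag),
                            'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')

    print(ckpt_name(0, '/ckpts'))   # /ckpts/mp_rank_00_model_states.pt
    print(ckpt_name(1, '/ckpts'))   # /ckpts/mp_rank_01_model_states.pt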
diff --git a/tests/trainers/test_finetune_gpt3.py b/tests/trainers/test_finetune_gpt3.py
new file mode 100644
index 00000000..e2110cfa
--- /dev/null
+++ b/tests/trainers/test_finetune_gpt3.py
@@ -0,0 +1,129 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+
+
+class TestFinetuneTextGeneration(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skip
+    def test_finetune_poetry(self):
+        dataset_dict = MsDataset.load('chinese-poetry-collection')
+        train_dataset = dataset_dict['train'].to_hf_dataset().rename_columns(
+            {'text1': 'src_txt'})
+        eval_dataset = dataset_dict['test'].to_hf_dataset().rename_columns(
+            {'text1': 'src_txt'})
+        max_epochs = 10
+        tmp_dir = './gpt3_poetry'
+
+        num_warmup_steps = 100
+
+        def noam_lambda(current_step: int):
+            current_step += 1
+            return min(current_step**(-0.5),
+                       current_step * num_warmup_steps**(-1.5))
+
+        def cfg_modify_fn(cfg):
+            cfg.train.lr_scheduler = {
+                'type': 'LambdaLR',
+                'lr_lambda': noam_lambda,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
+            cfg.train.dataloader = {
+                'batch_size_per_gpu': 16,
+                'workers_per_gpu': 1
+            }
+            return cfg
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_1.3B',
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            max_epochs=max_epochs,
+            work_dir=tmp_dir,
+            cfg_modify_fn=cfg_modify_fn)
+
+        # Construct trainer and train
+        trainer = build_trainer(
+            name=Trainers.gpt3_trainer, default_args=kwargs)
+        trainer.train()
+
+    @unittest.skip
+    def test_finetune_dureader(self):
+        # DuReader_robust-QG is an example data set,
+        # users can also use their own data set for training
+        dataset_dict = MsDataset.load('DuReader_robust-QG')
+
+        train_dataset = dataset_dict['train'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
+            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '') + '\n'})
+        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
+            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '') + '\n'})
+
+        max_epochs = 10
+        tmp_dir = './gpt3_dureader'
+
+        num_warmup_steps = 200
+
+        def noam_lambda(current_step: int):
+            current_step += 1
+            return min(current_step**(-0.5),
+                       current_step * num_warmup_steps**(-1.5))
+
+        def cfg_modify_fn(cfg):
+            cfg.train.lr_scheduler = {
+                'type': 'LambdaLR',
+                'lr_lambda': noam_lambda,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
+            cfg.train.dataloader = {
+                'batch_size_per_gpu': 16,
+                'workers_per_gpu': 1
+            }
+            cfg.train.hooks.append({
+                'type': 'EvaluationHook',
+                'by_epoch': True,
+                'interval': 1
+            })
+            cfg.preprocessor.sequence_length = 512
+            cfg.model.checkpoint_model_parallel_size = 1
+            return cfg
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_1.3B',
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            max_epochs=max_epochs,
+            work_dir=tmp_dir,
+            cfg_modify_fn=cfg_modify_fn)
+
+        # Construct trainer and train
+        trainer = build_trainer(
+            name=Trainers.gpt3_trainer, default_args=kwargs)
+        trainer.train()
+
+
+if __name__ == '__main__':
+    unittest.main()
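Both tests use the same Noam-style schedule: the learning-rate factor grows linearly for num_warmup_steps steps and then decays as step**-0.5, peaking once current_step + 1 equals num_warmup_steps. A quick standalone check:

    num_warmup_steps = 100

    def noam_lambda(current_step: int):
        current_step += 1
        return min(current_step**(-0.5),
                   current_step * num_warmup_steps**(-1.5))

    # peak factor num_warmup_steps**-0.5 = 0.1 occurs at step 99
    assert max(range(1000), key=noam_lambda) == 99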
diff --git a/tests/trainers/test_finetune_text_generation.py b/tests/trainers/test_finetune_text_generation.py
index 59bef51c..9981e228 100644
--- a/tests/trainers/test_finetune_text_generation.py
+++ b/tests/trainers/test_finetune_text_generation.py
@@ -80,7 +80,8 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)
 
-        trainer = build_trainer(default_args=kwargs)
+        trainer = build_trainer(
+            name=Trainers.text_generation_trainer, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)