Add GPT-3 tensor-parallel finetuning and adjust the distributed code so that tensor parallelism and data parallelism work together. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10949507 (master^2)
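Not part of the diff itself, but useful context on how the pieces below fit together: the new GPT3Trainer pulls the worker rank and master address out of the environment and copies them into the model config, GPT3ForTextGeneration routes to DistributedGPT3 whenever model_parallel_size shows up in its kwargs, and initialize_distributed plus megatron's mpu then build the tensor-parallel and data-parallel groups. A multi-GPU finetuning run is therefore expected to be launched with one process per GPU, e.g. with torchrun, which sets exactly the variables the trainer reads. A minimal sketch of what each worker sees (the launcher script name is hypothetical):

```python
# Environment seen by each worker under e.g. `torchrun --nproc_per_node=2 finetune_gpt3.py`
# (finetune_gpt3.py is a placeholder; the variable names match GPT3Trainer.rebuild_config).
import os

rank = int(os.environ.get('LOCAL_RANK', -1))
master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
master_port = os.environ.get('MASTER_PORT', '29500')
world_size = int(os.environ.get('WORLD_SIZE', 1))  # also honored by initialize_distributed below
print(rank, master_ip, master_port, world_size)
```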
@@ -349,6 +349,7 @@ class Trainers(object):
     nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
     text_generation_trainer = 'text-generation-trainer'
     nlp_plug_trainer = 'nlp-plug-trainer'
+    gpt3_trainer = 'nlp-gpt3-trainer'
     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'

@@ -44,7 +44,7 @@ class TextGenerationMetric(Metric):
         def remove_useless(string: str) -> str:
             return string.replace(' ', '').replace('.', '')
-        return remove_useless(pred) and remove_useless(tgt)
+        return len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0

     def evaluate(self):
         assert self.preds, 'preds in TextGenerationMetric must not be empty!'
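The old return handed back whichever string the `and` short-circuited to; in a boolean context the behaviour was the same, but the rewritten line makes the intent (skip pairs where either side is empty after stripping spaces and periods) explicit and returns an actual bool. A standalone check of the two forms:

```python
# Standalone check; remove_useless mirrors the helper above.
def remove_useless(s: str) -> str:
    return s.replace(' ', '').replace('.', '')

pred, tgt = 'hello .', ' . '
old = remove_useless(pred) and remove_useless(tgt)                      # -> '' (a str)
new = len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0  # -> False (a bool)
print(repr(old), new)
```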
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
     from .csanmt import CsanmtForTranslation
     from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model
     from .gpt_neo import GPTNeoModel
-    from .gpt3 import GPT3ForTextGeneration
+    from .gpt3 import GPT3ForTextGeneration, DistributedGPT3
     from .heads import SequenceClassificationHead
     from .palm_v2 import PalmForTextGeneration
     from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig

@@ -59,7 +59,7 @@ else:
         'bart': ['BartForTextErrorCorrection'],
         'csanmt': ['CsanmtForTranslation'],
         'heads': ['SequenceClassificationHead'],
-        'gpt3': ['GPT3ForTextGeneration'],
+        'gpt3': ['GPT3ForTextGeneration', 'DistributedGPT3'],
         'structbert': [
             'SbertForFaqQuestionAnswering',
             'SbertForMaskedLM',

@@ -8,12 +8,14 @@ if TYPE_CHECKING:
     from .backbone import GPT3Model
     from .text_generation import GPT3ForTextGeneration
     from .tokenizer import JiebaBPETokenizer
+    from .distributed_gpt3 import DistributedGPT3
 else:
     _import_structure = {
         'configuration': ['GPT3Config'],
         'backbone': ['GPT3Model'],
         'text_generation': ['GPT3ForTextGeneration'],
         'tokenizer': ['JiebaBPETokenizer'],
+        'distributed_gpt3': ['DistributedGPT3'],
     }
     import sys
@@ -13,7 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
 import math
+import os
+from os import path as osp
+from typing import Callable, Dict, List, Optional, Union
 import torch
 from megatron import mpu

@@ -25,8 +29,14 @@ from torch import nn
+from torch.nn import functional as F
 from transformers.modeling_utils import PreTrainedModel
+from modelscope.fileio import File
+from modelscope.metainfo import Models
 from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
 from modelscope.models.nlp.gpt3 import GPT3Config
+from modelscope.outputs import TextGenerationModelOutput, TokenGeneratorOutput
+from modelscope.utils.checkpoint import weights_to_cpu
 from modelscope.utils.constant import Tasks
 from modelscope.utils.nlp.distributed import initialize_distributed
 from modelscope.utils.nlp.load_checkpoint import pre_load
 from modelscope.utils.torch_utils import set_random_seed_mpu
@@ -435,7 +445,7 @@ class nullcontext:
 def bias_dropout_add(x, bias, residual, prob, training):
     # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
-    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    out = F.dropout(x + bias, p=prob, training=training)
     out = residual + out
     return out

@@ -747,11 +757,9 @@ class GPT3Model(PreTrainedModel):
     config_class = GPT3Config

-    def __init__(self, config, parallel_output=False):
+    def __init__(self, config):
         super().__init__(config)
-        self.parallel_output = parallel_output
         self.language_model = GPT3TransformerLanguageModel(
             config, init_method_normal(config.init_method_std),
             scaled_init_method_normal(config.init_method_std,

@@ -764,9 +772,7 @@ class GPT3Model(PreTrainedModel):
     def build_attention_mask_and_position_ids(tokens):
         seq_length = tokens.size(1)
         attention_mask = torch.tril(
-            torch.ones((1, 1, seq_length, seq_length),
-                       dtype=torch.long,
-                       device=tokens.device))
+            torch.ones((1, 1, seq_length, seq_length), device=tokens.device))
         attention_mask = (attention_mask < 0.5)
         position_ids = torch.arange(

@@ -780,6 +786,7 @@ class GPT3Model(PreTrainedModel):
                 attention_mask=None,
                 position_ids=None,
                 inference_params=None,
+                labels=None,
                 **kwargs):
         if attention_mask is None and position_ids is None:
             attention_mask, position_ids = \

@@ -797,9 +804,18 @@ class GPT3Model(PreTrainedModel):
         # Gather if needed.
         output = logits_parallel
-        if not self.parallel_output:
+        if labels is None:
             output = mpu.gather_from_model_parallel_region(logits_parallel)
-        return output.transpose(0, 1).contiguous()
+            # [s b h] => [b s h]
+            return output.transpose(0, 1).contiguous()
+        else:
+            # [b s] => [s b]
+            labels = labels.transpose(0, 1).contiguous()
+            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+            # [s b] => [b s]
+            loss = loss.transpose(0, 1).contiguous()
+            return loss

 def modify_logits_for_top_k_filtering(logits, top_k):
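The new labels branch keeps the logits sharded across tensor-parallel ranks and feeds them to mpu.vocab_parallel_cross_entropy instead of gathering first; the transposes only convert between the model's sequence-first [s, b, ...] layout and the caller's [b, s] layout. On a single rank the same computation collapses to an ordinary per-token cross entropy; a minimal sketch assuming no tensor parallelism:

```python
# Single-rank equivalent of the labels branch above (shapes are illustrative).
import torch
import torch.nn.functional as F

s, b, vocab = 5, 2, 100
logits = torch.randn(s, b, vocab)          # model output, sequence-first
labels = torch.randint(0, vocab, (b, s))   # caller layout, batch-first

loss_bs = F.cross_entropy(
    logits.permute(1, 2, 0),               # [b, vocab, s], as cross_entropy expects
    labels,
    reduction='none')                      # per-token loss, [b, s]
print(loss_bs.shape)                       # torch.Size([2, 5])
```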
@@ -911,6 +927,51 @@ class InferenceParams:
                 new_inference_key_memory, new_inference_value_memory)

+def split_into_partitions(tensor, num_partitions, partition_dim, stride):
+    per_partition_size = mpu.utils.divide(
+        tensor.size(partition_dim), num_partitions)
+    per_partition_per_stride_size = mpu.utils.divide(per_partition_size,
+                                                     stride)
+    partitions_list = torch.split(
+        tensor, per_partition_per_stride_size, dim=partition_dim)
+    partitions = []
+    for i in range(num_partitions):
+        partition = torch.cat(
+            partitions_list[i::num_partitions], dim=partition_dim)
+        partitions.append(partition)
+    return partitions
+
+
+def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
+                     partitions: int) -> Dict[str, torch.Tensor]:
+    if partitions == 1:
+        return state_dict
+    rank: int = mpu.get_model_parallel_rank()
+    for name, parameters in model.named_parameters():
+        if parameters.shape == state_dict[name].shape:
+            continue
+        dim = max(parameters.partition_dim, 0)
+        stride = parameters.partition_stride
+        state_dict[name] = split_into_partitions(state_dict[name], partitions,
+                                                 dim, stride)[rank]
+    return state_dict
+
+
+def save_checkpoint(model: torch.nn.Module, filename: str) -> None:
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model = model.module
+    checkpoint = {'module': weights_to_cpu(model.state_dict())}
+    mp_rank = mpu.get_model_parallel_rank()
+    filename = osp.join(
+        osp.dirname(filename), 'model',
+        'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
+    with io.BytesIO() as f:
+        torch.save(checkpoint, f)
+        File.write(f.getvalue(), filename)

 class DistributedGPT3(TorchModel):
     def __init__(self,
@@ -942,33 +1003,63 @@ class DistributedGPT3(TorchModel):
             model = Float16Module(model, self.config)
         self.dist_model = model
-        load_model = pre_load(mpu, model_dir, tag=path_load_tag)
+        tensor_ws = mpu.get_model_parallel_world_size()
+        ckpt_ws = kwargs.pop('checkpoint_model_parallel_size', tensor_ws)
+        ckpt_rank = mpu.get_model_parallel_rank() * ckpt_ws // tensor_ws
+        load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag)
+        load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws)
         self.dist_model.load_state_dict(load_model)
         self.inference_params = None

-    def forward_step(self, tokens, attention_mask, position_ids):
-        logits = self.dist_model(
+    def train(self, mode: bool = True):
+        if mode:
+            self.inference_params = None
+        return super().train(mode)
+
+    def forward(self,
+                tokens,
+                attention_mask=None,
+                position_ids=None,
+                labels=None,
+                prompt_length=None):
+        outputs = self.dist_model(
             tokens,
             attention_mask,
             position_ids,
-            inference_params=self.inference_params)
-        self.inference_params.sequence_len_offset += tokens.size(1)
-        return logits
+            inference_params=self.inference_params,
+            labels=labels)
+        if labels is None:
+            self.inference_params.sequence_len_offset += tokens.size(1)
+            return TextGenerationModelOutput(logits=outputs)
+        else:
+            loss_mask = torch.ones(
+                tokens.size(), dtype=torch.float, device=tokens.device)
+            losses = outputs.float()
+            loss_mask = loss_mask.view(-1).float()
+            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+            return TextGenerationModelOutput(loss=loss)

     def generate(self,
                  tokens,
                  temperature=1.0,
                  use_eod_token_for_early_termination=True,
                  stop_on_double_eol=False,
-                 stop_on_eol=False):
-        lengths = torch.tensor([tokens.size(1)], device=tokens.device)
+                 stop_on_eol=False,
+                 **kwargs):
+        batch_size = tokens.size(0)
+        lengths = kwargs.pop(
+            'prompt_length',
+            torch.tensor([tokens.size(1)], device=tokens.device))
         pads = torch.ones(
-            1, self.config.tokens_to_generate,
+            batch_size, self.config.tokens_to_generate,
             device=tokens.device).long() * self.config.eod_id
         tokens = torch.cat((tokens, pads), dim=-1)
-        batch_size = tokens.size(0)
         min_prompt_length = lengths.min().item()
         max_sequence_length = tokens.size(1)
         max_sequence_length = min(max_sequence_length,
@@ -1009,8 +1100,8 @@ class DistributedGPT3(TorchModel):
                 ..., prev_context_length:context_length, :context_length]

             # logits will be meaningful only in the last pipeline stage.
-            logits = self.forward_step(tokens2use, attention_mask2use,
-                                       positions2use)
+            logits = self(tokens2use, attention_mask2use,
+                          positions2use).logits

             # Sample.
             last_token_logits = logits[:, -1, :]
| @@ -1054,4 +1145,16 @@ class DistributedGPT3(TorchModel): | |||
| break | |||
| tokens = tokens[:, :(context_length + 1)] | |||
| return tokens | |||
| return TokenGeneratorOutput(sequences=tokens) | |||
| def state_dict(self): | |||
| return self.dist_model.state_dict() | |||
| def save_pretrained(self, | |||
| target_folder: Union[str, os.PathLike], | |||
| save_checkpoint_names: Union[str, List[str]] = None, | |||
| save_function: Callable = save_checkpoint, | |||
| config: Optional[dict] = None, | |||
| **kwargs): | |||
| return super().save_pretrained(target_folder, save_checkpoint_names, | |||
| save_function, config, **kwargs) | |||
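Two details of the checkpoint handling above are worth spelling out. Loading: when the checkpoint was written with a smaller model-parallel size than the current run (checkpoint_model_parallel_size smaller than the tensor-parallel world size), each rank first picks the checkpoint shard it maps onto and then keeps only its slice of every partitioned weight via split_state_dict. Saving: save_pretrained goes through save_checkpoint, so every tensor-parallel rank writes its own model/mp_rank_XX_model_states.pt under the target folder. A plain-torch illustration of the case the DuReader test at the bottom exercises (a single-shard checkpoint loaded into a 2-way run):

```python
# Illustration only; no megatron required.
import torch

tensor_ws, ckpt_ws = 2, 1                        # current vs. checkpoint model-parallel size
for tp_rank in range(tensor_ws):
    ckpt_rank = tp_rank * ckpt_ws // tensor_ws   # both ranks read shard 0
    print(f'tp rank {tp_rank} reads mp_rank_{ckpt_rank:02d}_model_states.pt')

# Each rank then keeps its slice of every weight whose shape does not match yet
# (here: partition dim 0, stride 1, split into tensor_ws // ckpt_ws pieces).
weight = torch.arange(12.).reshape(4, 3)
shards = torch.split(weight, weight.size(0) // (tensor_ws // ckpt_ws), dim=0)
print(shards[0].shape, shards[1].shape)          # rank 0 keeps shards[0], rank 1 shards[1]
```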
@@ -1,10 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 from typing import Dict
+from transformers import BertTokenizer
 from modelscope.metainfo import Models
 from modelscope.models.base import Tensor, TorchModel
 from modelscope.models.builder import MODELS
 from modelscope.outputs import OutputKeys
+from modelscope.models.nlp.gpt3 import GPT3Model
 from modelscope.utils.constant import Tasks

 __all__ = ['GPT3ForTextGeneration']

@@ -21,11 +24,15 @@ class GPT3ForTextGeneration(TorchModel):
         """
         super().__init__(model_dir, *args, **kwargs)
-        from modelscope.models.nlp.gpt3 import GPT3Model
-        from transformers import BertTokenizer
-        self.model = GPT3Model.from_pretrained(model_dir)
-        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+        # Temporarily compatible with DistributedGPT3 and GPT3Model,
+        # the base/large model based on GPT3Model will be replaced in the future,
+        # and GPT3Model will be deprecated
+        if 'model_parallel_size' in kwargs:
+            from modelscope.models.nlp import DistributedGPT3
+            self.model = DistributedGPT3(model_dir, **kwargs)
+        else:
+            self.model = GPT3Model.from_pretrained(model_dir)
+            self.tokenizer = BertTokenizer.from_pretrained(model_dir)
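Which implementation is used is decided purely by whether model_parallel_size reaches the constructor's kwargs; for the large checkpoints it comes from the model configuration, and the trainer additionally injects rank/master_ip/master_port (see GPT3Trainer below). A hedged sketch of the two paths, with a placeholder model directory:

```python
# Sketch only; '/path/to/gpt3' is a placeholder and the distributed branch needs a
# megatron-enabled multi-GPU environment plus the rank/master settings from the config.
from modelscope.models.nlp import GPT3ForTextGeneration

small = GPT3ForTextGeneration('/path/to/gpt3')                      # -> GPT3Model backbone
large = GPT3ForTextGeneration('/path/to/gpt3', model_parallel_size=2,
                              rank=0, master_ip='127.0.0.1',
                              master_port='29500')                  # -> DistributedGPT3
```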
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model

@@ -43,6 +50,9 @@ class GPT3ForTextGeneration(TorchModel):
         return self.model(**input)

     def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        if not isinstance(self.model, GPT3Model):
+            return self.model.generate(**input)
         assert 'input_ids' in input, "generate function must accept 'input_ids' key"
         input_ids = input['input_ids']
         if 'attention_mask' in input:

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List
 from tokenizers import Tokenizer
@@ -25,9 +27,11 @@ class JiebaBPETokenizer:
         self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
         try:
             import jieba
+            import logging
+            jieba.setLogLevel(logging.INFO)
         except ImportError:
             raise ImportError(
-                'You need to install rjieba to use JiebaTokenizer. '
+                'You need to install jieba to use JiebaTokenizer. '
                 'See https://pypi.org/project/rjieba/ for installation.')
         self.jieba = jieba
         self.new_line = self.vocab['\n']

@@ -49,7 +53,7 @@ class JiebaBPETokenizer:
             inv_vocab[val] = key
         return inv_vocab

-    def tokenize(self, text, is_code=False):
+    def tokenize(self, text: str, is_code: bool = False) -> List[int]:
         """
         """
         if not is_code:

@@ -61,7 +65,7 @@ class JiebaBPETokenizer:
             text, is_pretokenized=False, add_special_tokens=True).ids

     def detokenize(self, token_ids):
-        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
+        text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
         return text
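With skip_special_tokens=True the eod padding appended during generation no longer leaks into the decoded text. A round-trip sketch (the tokenizer.json path is a placeholder; jieba and tokenizers must be installed):

```python
# Placeholder path; tokenizer.json ships with the GPT-3 model directory.
from modelscope.models.nlp.gpt3 import JiebaBPETokenizer

tok = JiebaBPETokenizer('/path/to/tokenizer.json')
ids = tok.tokenize('今天天气不错')   # List[int]: jieba pre-segmentation, then BPE
print(tok.detokenize(ids))          # special tokens such as <|endoftext|> are now dropped
```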
     @property

@@ -110,7 +110,8 @@ class DistributedPlug(TorchModel):
             if 'LayerNorm' in name:
                 _module.float()
-        load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
+        load_model = pre_load(
+            mpu.get_model_parallel_rank(), self.model_dir, tag=path_load_tag)
         model_dict = model.module.model.state_dict()
         for key in load_model:
             if key not in model_dict.keys():

@@ -5,7 +5,7 @@ from typing import Any, Dict
 import torch
 from modelscope.metainfo import Pipelines
-from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3
+from modelscope.models.nlp import DistributedGPT3
 from modelscope.pipelines.base import DistributedPipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import TextGenerationJiebaPreprocessor

@@ -30,7 +30,7 @@ class DistributedGPT3Pipeline(DistributedPipeline):
                 Extra kwargs passed into the preprocessor's constructor.
         """
         if preprocessor is None:
-            preprocessor = TextGenerationJiebaPreprocessor(model, **kwargs)
+            preprocessor = TextGenerationJiebaPreprocessor(model)
         super().__init__(model, preprocessor=preprocessor, **kwargs)
         assert hasattr(preprocessor, 'tokenizer')

@@ -58,5 +58,6 @@ class DistributedGPT3Pipeline(DistributedPipeline):
         from modelscope.outputs import OutputKeys
         return {
             OutputKeys.TEXT:
-            self.preprocessor.tokenizer.detokenize(inputs[0].tolist())
+            self.preprocessor.tokenizer.detokenize(
+                inputs.sequences[0].tolist())
         }
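The postprocess change follows from DistributedGPT3.generate now returning a TokenGeneratorOutput rather than a raw tensor: the generated ids live under .sequences, hence inputs.sequences[0] before detokenizing. End-to-end usage of the pipeline is unchanged; a hedged sketch (the model id is the one used in the tests below, and a megatron-enabled multi-GPU environment is assumed):

```python
# Sketch only; downloading and sharding the 1.3B checkpoint requires GPUs and megatron.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_generation, model='damo/nlp_gpt3_text-generation_1.3B')
print(pipe('今天天气不错'))   # -> {'text': '...'}
```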
@@ -30,7 +30,6 @@ class TextGenerationPipeline(Pipeline):
                  device: str = 'gpu',
                  auto_collate=True,
                  first_sequence='sentence',
-                 sequence_length=128,
                  **kwargs):
         """Use `model` and `preprocessor` to create a generation pipeline for prediction.

@@ -63,10 +62,7 @@ class TextGenerationPipeline(Pipeline):
         if preprocessor is None:
             self.preprocessor = Preprocessor.from_pretrained(
-                self.model.model_dir,
-                first_sequence=first_sequence,
-                sequence_length=sequence_length,
-                **kwargs)
+                self.model.model_dir, first_sequence=first_sequence, **kwargs)
         self.model.eval()
         self.postprocessor = kwargs.pop('postprocessor', None)
         if self.postprocessor is None and hasattr(self.model, 'model_dir'):

@@ -192,7 +192,9 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
                  model_dir: str,
                  mode: str = ModeKeys.INFERENCE,
                  src_txt='src_txt',
-                 tgt_txt=None):
+                 tgt_txt=None,
+                 sequence_length: int = 128,
+                 use_fast=None):
         from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
         super().__init__(mode, src_txt, tgt_txt)
         if self.tgt_txt is not None:

@@ -202,6 +204,7 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
         self.src_txt = src_txt
         self.tokenizer = JiebaBPETokenizer(
             osp.join(model_dir, 'tokenizer.json'))
+        self.max_length = sequence_length

     def decode(self, tokens, **kwargs):
         """Decode the tokens to real text.

@@ -214,6 +217,14 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
         """
         return self.tokenizer.detokenize(tokens)

+    def _truncate(self, array: np.ndarray) -> np.ndarray:
+        if len(array) < self.max_length:
+            return np.pad(
+                array, (0, self.max_length - len(array)),
+                constant_values=self.tokenizer.eod)
+        else:
+            return array[:self.max_length]
+
     def _tokenize_text(self, sequence1, sequence2=None, **kwargs):
         """Tokenize the text.

@@ -224,10 +235,22 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
         Returns:
             The encoded sequence.
         """
-        return {
-            'input_ids':
-            torch.tensor(self.tokenizer.tokenize(sequence1)).unsqueeze_(0)
-        }
+        if self.mode == ModeKeys.INFERENCE:
+            return {
+                'input_ids':
+                torch.tensor(self.tokenizer.tokenize(sequence1)).unsqueeze_(0)
+            }
+        else:
+            tokens = self.tokenizer.tokenize(sequence1)
+            prompt_length = min(len(tokens), self.max_length - 1)
+            if sequence2 is not None:
+                tokens += self.tokenizer.tokenize(sequence2)
+            tokens = self._truncate(np.array(tokens))
+            return {
+                'tokens': tokens[:-1],
+                'labels': tokens[1:],
+                'prompt_length': prompt_length,
+            }
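In training and evaluation mode the preprocessor now emits next-token-prediction pairs instead of a batched input_ids tensor: prompt and target are tokenized, concatenated, padded (or cut) to sequence_length with the eod id, and tokens/labels are the usual one-token shift of each other, with prompt_length recording where the prompt ends. A toy trace with made-up ids:

```python
# Made-up token ids; sequence_length (self.max_length) = 8, eod id = 0.
import numpy as np

prompt, target = [11, 12, 13, 14], [21, 22]
tokens = np.array(prompt + target)
prompt_length = min(len(prompt), 8 - 1)                          # 4
padded = np.pad(tokens, (0, 8 - len(tokens)), constant_values=0)
print(padded[:-1].tolist())   # tokens: [11, 12, 13, 14, 21, 22, 0]
print(padded[1:].tolist())    # labels: [12, 13, 14, 21, 22, 0, 0]
print(prompt_length)
```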
 @PREPROCESSORS.register_module(

@@ -74,7 +74,7 @@ class TextLoggerHook(LoggerHook):
             self._dump_log(trainer.meta)

     def _get_max_memory(self, trainer):
-        device = getattr(trainer.model, 'output_device', None)
+        device = torch.cuda.current_device()
         mem = torch.cuda.max_memory_allocated(device=device)
         mem_mb = torch.tensor([mem / (1024 * 1024)],
                               dtype=torch.int,

@@ -0,0 +1,61 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from collections.abc import Mapping
+from typing import List
+
+import torch
+from megatron import mpu
+
+from modelscope.metainfo import Trainers
+from modelscope.models import TorchModel
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.utils.config import Config
+from modelscope.utils.file_utils import func_receive_dict_inputs
+
+
+@TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
+class GPT3Trainer(NlpEpochBasedTrainer):
+
+    def rebuild_config(self, cfg: Config):
+        super().rebuild_config(cfg)
+        cfg.model.rank = int(os.environ.get('LOCAL_RANK', -1))
+        cfg.model.master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
+        cfg.model.master_port = os.environ.get('MASTER_PORT', '29500')
+        return cfg
+
+    def train_step(self, model: TorchModel, inputs: Mapping):
+        keys = list(inputs.keys())
+        datatype = torch.int64
+        inputs = mpu.broadcast_data(keys, inputs, datatype)
+        return super().train_step(model, inputs)
+
+    def _decode(self, tokens):
+        tokenizer = self.eval_preprocessor.tokenizer
+        return tokenizer.detokenize(tokens.tolist())
+
+    def evaluation_step(self, data):
+        model = self.model.module if self._dist else self.model
+        model.eval()
+
+        with torch.no_grad():
+            if isinstance(
+                    data,
+                    Mapping) and not func_receive_dict_inputs(model.generate):
+                result = model.generate(**data)
+            else:
+                result = model.generate(data)
+
+        prompt_length: List[int] = data['prompt_length']
+        result['preds'] = [
+            self._decode(seq[skip_len:])
+            for seq, skip_len in zip(result['sequences'], prompt_length)
+        ]
+        data['tgts'] = [
+            self._decode(seq[skip_len - 1:])
+            for seq, skip_len in zip(data['labels'], prompt_length)
+        ]
+        assert len(result['preds']) == len(data['tgts'])
+        return result
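At evaluation time the trainer decodes the generated sequences and the shifted labels back into text, dropping the prompt portion using prompt_length; since labels start one position later than tokens, the reference side skips skip_len - 1 elements. A toy view with made-up ids:

```python
# Made-up ids; prompt_length = 4, eod id = 0.
sequence = [11, 12, 13, 14, 51, 52, 53]   # prompt + generated continuation
labels = [12, 13, 14, 21, 22, 0, 0]       # shifted reference from the preprocessor
skip_len = 4
print(sequence[skip_len:])       # [51, 52, 53]    -> decoded as the prediction
print(labels[skip_len - 1:])     # [21, 22, 0, 0]  -> decoded as the target
```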
@@ -236,7 +236,7 @@ class EpochBasedTrainer(BaseTrainer):
             device_name: The final device name.
         """
         device_name = device if device is not None else 'gpu'
-        if self._dist:
+        if dist.is_initialized():
             local_rank = get_local_rank()
             device_name = f'cuda:{local_rank}'

@@ -137,7 +137,8 @@ def multi_gpu_test(trainer,
         else:
             batch_size = len(data)
         if i >= (data_len // world_size) - 1:
-            total_samples = torch.LongTensor([batch_size]).to(model.device)
+            total_samples = torch.LongTensor([batch_size
+                                              ]).to(trainer.model.device)
             dist.all_reduce(total_samples, op=dist.reduce_op.SUM)
             total_samples = total_samples.item()
         else:

@@ -14,6 +14,7 @@
 # limitations under the License.
 import math
+import os

 import torch
 import torch.distributed as dist

@@ -36,7 +37,7 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size,
     init_method += master_ip + ':' + master_port
     torch.distributed.init_process_group(
         backend='nccl',
-        world_size=world_size,
+        world_size=int(os.getenv('WORLD_SIZE', world_size)),
         rank=rank,
         init_method=init_method)
     # Set the model-parallel communicators.

@@ -55,16 +55,15 @@ def load_checkpoint(model,
     return load_path, client_states

-def _get_ckpt_name(mpu, checkpoints_path, tag):
-    mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
+def _get_ckpt_name(mp_rank, checkpoints_path, tag):
     ckpt_name = os.path.join(
         checkpoints_path, str(tag),
         'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
     return ckpt_name

-def pre_load(mpu, load_dir, tag=''):
-    load_path = _get_ckpt_name(mpu, load_dir, tag)
+def pre_load(mp_rank, load_dir, tag=''):
+    load_path = _get_ckpt_name(mp_rank, load_dir, tag)
     checkpoint = torch.load(
         load_path, map_location=lambda storage, loc: storage)
     return checkpoint['module']

@@ -107,8 +107,14 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
 def get_dist_info() -> Tuple[int, int]:
     if dist.is_available() and dist.is_initialized():
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
+        try:
+            from megatron import mpu
+            assert mpu.model_parallel_is_initialized()
+            rank = mpu.get_data_parallel_rank()
+            world_size = mpu.get_data_parallel_world_size()
+        except (ImportError, AssertionError):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
     else:
         rank = 0
         world_size = 1
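With megatron's model-parallel groups initialized, get_dist_info now reports the data-parallel rank and world size instead of the global ones, so everything keyed on it (master-only logging, temporary directories, evaluation sample counting) sees one process per model replica rather than per GPU. For example, under an assumed launch layout:

```python
# 8 worker processes with 2-way tensor parallelism -> 4 data-parallel replicas,
# so get_dist_info() reports world_size 4 and ranks 0..3 (one per replica).
world_size, model_parallel_size = 8, 2
data_parallel_world_size = world_size // model_parallel_size
print(data_parallel_world_size)   # 4
```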
@@ -120,16 +126,14 @@ def get_local_rank():
 def is_master():
-    rank, _ = get_dist_info()
-    return rank == 0
+    return dist.get_rank() == 0 if dist.is_initialized() else True

 def master_only(func: Callable) -> Callable:

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        rank, _ = get_dist_info()
-        if rank == 0:
+        if is_master():
             return func(*args, **kwargs)

     return wrapper

@@ -138,12 +142,11 @@ def master_only(func: Callable) -> Callable:
 def make_tmp_dir():
     """Make sure each rank has the same temporary directory on the distributed mode.
     """
-    rank, world_size = get_dist_info()
-    if world_size <= 1:
+    if not dist.is_initialized():
         return tempfile.mkdtemp()

     tmpdir = None
-    if rank == 0:
+    if is_master():
         tmpdir = tempfile.mkdtemp()
     dist.barrier()

@@ -162,7 +165,7 @@ def broadcast(inputs, src):
     Returns:
         Each rank returns the same value as src.
     """
-    rank, _ = get_dist_info()
+    rank = dist.get_rank()

     shape_tensor = torch.tensor([0], device='cuda')
     if rank == src:
@@ -0,0 +1,129 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+
+
+class TestFinetuneTextGeneration(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skip
+    def test_finetune_poetry(self):
+        dataset_dict = MsDataset.load('chinese-poetry-collection')
+        train_dataset = dataset_dict['train'].to_hf_dataset().rename_columns(
+            {'text1': 'src_txt'})
+        eval_dataset = dataset_dict['test'].to_hf_dataset().rename_columns(
+            {'text1': 'src_txt'})
+        max_epochs = 10
+        tmp_dir = './gpt3_poetry'
+
+        num_warmup_steps = 100
+
+        def noam_lambda(current_step: int):
+            current_step += 1
+            return min(current_step**(-0.5),
+                       current_step * num_warmup_steps**(-1.5))
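noam_lambda is the inverse-square-root (Noam) schedule from "Attention Is All You Need", without the d_model factor: the multiplier grows linearly for num_warmup_steps steps and then decays as step**-0.5, peaking exactly at the warmup boundary. A quick check:

```python
# The factor peaks at current_step == num_warmup_steps (100 here).
num_warmup_steps = 100
factor = lambda step: min(step**(-0.5), step * num_warmup_steps**(-1.5))
print(factor(50) < factor(100) > factor(200))   # True
```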
+        def cfg_modify_fn(cfg):
+            cfg.train.lr_scheduler = {
+                'type': 'LambdaLR',
+                'lr_lambda': noam_lambda,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
+            cfg.train.dataloader = {
+                'batch_size_per_gpu': 16,
+                'workers_per_gpu': 1
+            }
+            return cfg
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_1.3B',
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            max_epochs=max_epochs,
+            work_dir=tmp_dir,
+            cfg_modify_fn=cfg_modify_fn)
+
+        # Construct trainer and train
+        trainer = build_trainer(
+            name=Trainers.gpt3_trainer, default_args=kwargs)
+        trainer.train()
+
+    @unittest.skip
+    def test_finetune_dureader(self):
+        # DuReader_robust-QG is an example data set,
+        # users can also use their own data set for training
+        dataset_dict = MsDataset.load('DuReader_robust-QG')
+        train_dataset = dataset_dict['train'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
+            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})
+        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
+            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})
+        max_epochs = 10
+        tmp_dir = './gpt3_dureader'
+
+        num_warmup_steps = 200
+
+        def noam_lambda(current_step: int):
+            current_step += 1
+            return min(current_step**(-0.5),
+                       current_step * num_warmup_steps**(-1.5))
+
+        def cfg_modify_fn(cfg):
+            cfg.train.lr_scheduler = {
+                'type': 'LambdaLR',
+                'lr_lambda': noam_lambda,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
+            cfg.train.dataloader = {
+                'batch_size_per_gpu': 16,
+                'workers_per_gpu': 1
+            }
+            cfg.train.hooks.append({
+                'type': 'EvaluationHook',
+                'by_epoch': True,
+                'interval': 1
+            })
+            cfg.preprocessor.sequence_length = 512
+            cfg.model.checkpoint_model_parallel_size = 1
+            return cfg
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_1.3B',
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            max_epochs=max_epochs,
+            work_dir=tmp_dir,
+            cfg_modify_fn=cfg_modify_fn)
+
+        # Construct trainer and train
+        trainer = build_trainer(
+            name=Trainers.gpt3_trainer, default_args=kwargs)
+        trainer.train()
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -80,7 +80,8 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)

-        trainer = build_trainer(default_args=kwargs)
+        trainer = build_trainer(
+            name=Trainers.text_generation_trainer, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)