[to #42322933] Add GPT-3 tensor parallel finetuning

Add GPT-3 tensor parallel finetuning, and adjust some of the distributed code so that tensor parallelism and data parallelism work together.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10949507
master^2 · hemu.zp · 3 years ago · parent commit 941dbe75cf
20 changed files with 403 additions and 67 deletions
  1. modelscope/metainfo.py (+1, -0)
  2. modelscope/metrics/text_generation_metric.py (+1, -1)
  3. modelscope/models/nlp/__init__.py (+2, -2)
  4. modelscope/models/nlp/gpt3/__init__.py (+2, -0)
  5. modelscope/models/nlp/gpt3/distributed_gpt3.py (+125, -22)
  6. modelscope/models/nlp/gpt3/text_generation.py (+16, -6)
  7. modelscope/models/nlp/gpt3/tokenizer.py (+7, -3)
  8. modelscope/models/nlp/plug/distributed_plug.py (+2, -1)
  9. modelscope/pipelines/nlp/distributed_gpt3_pipeline.py (+4, -3)
  10. modelscope/pipelines/nlp/text_generation_pipeline.py (+1, -5)
  11. modelscope/preprocessors/nlp/text_generation_preprocessor.py (+28, -5)
  12. modelscope/trainers/hooks/logger/text_logger_hook.py (+1, -1)
  13. modelscope/trainers/nlp/gpt3_trainer.py (+61, -0)
  14. modelscope/trainers/trainer.py (+1, -1)
  15. modelscope/trainers/utils/inference.py (+2, -1)
  16. modelscope/utils/nlp/distributed.py (+2, -1)
  17. modelscope/utils/nlp/load_checkpoint.py (+3, -4)
  18. modelscope/utils/torch_utils.py (+13, -10)
  19. tests/trainers/test_finetune_gpt3.py (+129, -0)
  20. tests/trainers/test_finetune_text_generation.py (+2, -1)

modelscope/metainfo.py (+1, -0)

@@ -349,6 +349,7 @@ class Trainers(object):
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
text_generation_trainer = 'text-generation-trainer'
nlp_plug_trainer = 'nlp-plug-trainer'
gpt3_trainer = 'nlp-gpt3-trainer'

# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'


modelscope/metrics/text_generation_metric.py (+1, -1)

@@ -44,7 +44,7 @@ class TextGenerationMetric(Metric):
def remove_useless(string: str) -> str:
return string.replace(' ', '').replace('.', '')

return remove_useless(pred) and remove_useless(tgt)
return len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0

def evaluate(self):
assert self.preds, 'preds in TextGenerationMetric must not be empty!'
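A note on the one-line change above: the old expression remove_useless(pred) and remove_useless(tgt) evaluates to a string (whichever operand short-circuits), not a bool, so the predicate only worked through string truthiness; the new form returns a real bool. A minimal sketch with made-up strings:

def remove_useless(string: str) -> str:
    return string.replace(' ', '').replace('.', '')

pred, tgt = 'a b', '. . .'
old_style = remove_useless(pred) and remove_useless(tgt)                      # '' -> a str
new_style = len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0  # False -> a bool
print(repr(old_style), repr(new_style))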


modelscope/models/nlp/__init__.py (+2, -2)

@@ -18,7 +18,7 @@ if TYPE_CHECKING:
from .csanmt import CsanmtForTranslation
from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model
from .gpt_neo import GPTNeoModel
from .gpt3 import GPT3ForTextGeneration
from .gpt3 import GPT3ForTextGeneration, DistributedGPT3
from .heads import SequenceClassificationHead
from .palm_v2 import PalmForTextGeneration
from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig
@@ -59,7 +59,7 @@ else:
'bart': ['BartForTextErrorCorrection'],
'csanmt': ['CsanmtForTranslation'],
'heads': ['SequenceClassificationHead'],
'gpt3': ['GPT3ForTextGeneration'],
'gpt3': ['GPT3ForTextGeneration', 'DistributedGPT3'],
'structbert': [
'SbertForFaqQuestionAnswering',
'SbertForMaskedLM',


modelscope/models/nlp/gpt3/__init__.py (+2, -0)

@@ -8,12 +8,14 @@ if TYPE_CHECKING:
from .backbone import GPT3Model
from .text_generation import GPT3ForTextGeneration
from .tokenizer import JiebaBPETokenizer
from .distributed_gpt3 import DistributedGPT3
else:
_import_structure = {
'configuration': ['GPT3Config'],
'backbone': ['GPT3Model'],
'text_generation': ['GPT3ForTextGeneration'],
'tokenizer': ['JiebaBPETokenizer'],
'distributed_gpt3': ['DistributedGPT3'],
}

import sys


modelscope/models/nlp/gpt3/distributed_gpt3.py (+125, -22)

@@ -13,7 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import math
import os
from os import path as osp
from typing import Callable, Dict, List, Optional, Union

import torch
from megatron import mpu
@@ -25,8 +29,14 @@ from torch import nn
from torch.nn import functional as F
from transformers.modeling_utils import PreTrainedModel

from modelscope.fileio import File
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.gpt3 import GPT3Config
from modelscope.outputs import TextGenerationModelOutput, TokenGeneratorOutput
from modelscope.utils.checkpoint import weights_to_cpu
from modelscope.utils.constant import Tasks
from modelscope.utils.nlp.distributed import initialize_distributed
from modelscope.utils.nlp.load_checkpoint import pre_load
from modelscope.utils.torch_utils import set_random_seed_mpu
@@ -435,7 +445,7 @@ class nullcontext:

def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = F.dropout(x + bias, p=prob, training=training)
out = residual + out
return out

@@ -747,11 +757,9 @@ class GPT3Model(PreTrainedModel):

config_class = GPT3Config

def __init__(self, config, parallel_output=False):
def __init__(self, config):
super().__init__(config)

self.parallel_output = parallel_output

self.language_model = GPT3TransformerLanguageModel(
config, init_method_normal(config.init_method_std),
scaled_init_method_normal(config.init_method_std,
@@ -764,9 +772,7 @@ class GPT3Model(PreTrainedModel):
def build_attention_mask_and_position_ids(tokens):
seq_length = tokens.size(1)
attention_mask = torch.tril(
torch.ones((1, 1, seq_length, seq_length),
dtype=torch.long,
device=tokens.device))
torch.ones((1, 1, seq_length, seq_length), device=tokens.device))
attention_mask = (attention_mask < 0.5)

position_ids = torch.arange(
@@ -780,6 +786,7 @@ class GPT3Model(PreTrainedModel):
attention_mask=None,
position_ids=None,
inference_params=None,
labels=None,
**kwargs):
if attention_mask is None and position_ids is None:
attention_mask, position_ids = \
@@ -797,9 +804,18 @@ class GPT3Model(PreTrainedModel):
# Gather if needed.

output = logits_parallel
if not self.parallel_output:

if labels is None:
output = mpu.gather_from_model_parallel_region(logits_parallel)
return output.transpose(0, 1).contiguous()
# [s b h] => [b s h]
return output.transpose(0, 1).contiguous()
else:
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
# [s b] => [b s]
loss = loss.transpose(0, 1).contiguous()
return loss


def modify_logits_for_top_k_filtering(logits, top_k):
@@ -911,6 +927,51 @@ class InferenceParams:
new_inference_key_memory, new_inference_value_memory)


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
per_partition_size = mpu.utils.divide(
tensor.size(partition_dim), num_partitions)
per_partition_per_stride_size = mpu.utils.divide(per_partition_size,
stride)
partitions_list = torch.split(
tensor, per_partition_per_stride_size, dim=partition_dim)
partitions = []
for i in range(num_partitions):
partition = torch.cat(
partitions_list[i::num_partitions], dim=partition_dim)
partitions.append(partition)
return partitions


def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
partitions: int) -> Dict[str, torch.Tensor]:
if partitions == 1:
return state_dict
rank: int = mpu.get_model_parallel_rank()
for name, parameters in model.named_parameters():
if parameters.shape == state_dict[name].shape:
continue
dim = max(parameters.partition_dim, 0)
stride = parameters.partition_stride
state_dict[name] = split_into_partitions(state_dict[name], partitions,
dim, stride)[rank]
return state_dict


def save_checkpoint(model: torch.nn.Module, filename: str) -> None:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module

checkpoint = {'module': weights_to_cpu(model.state_dict())}
mp_rank = mpu.get_model_parallel_rank()
filename = osp.join(
osp.dirname(filename), 'model',
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')

with io.BytesIO() as f:
torch.save(checkpoint, f)
File.write(f.getvalue(), filename)
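To make split_into_partitions and split_state_dict above concrete, here is a small self-contained sketch (mpu.utils.divide replaced by plain integer division, tensor sizes made up). It shows how one full weight from a checkpoint written with checkpoint_model_parallel_size=1 would be sliced for two tensor-parallel ranks; a stride greater than 1 interleaves blocks so that fused parameters keep matching sub-blocks on each rank.

import torch

def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    # Same arithmetic as above, with mpu.utils.divide replaced by exact integer division.
    per_partition_size = tensor.size(partition_dim) // num_partitions
    per_partition_per_stride_size = per_partition_size // stride
    partitions_list = torch.split(
        tensor, per_partition_per_stride_size, dim=partition_dim)
    return [
        torch.cat(partitions_list[i::num_partitions], dim=partition_dim)
        for i in range(num_partitions)
    ]

full_weight = torch.arange(8).unsqueeze(1)   # toy [8, 1] weight from the full checkpoint

rank0, rank1 = split_into_partitions(full_weight, num_partitions=2, partition_dim=0, stride=1)
print(rank0.squeeze().tolist(), rank1.squeeze().tolist())   # [0, 1, 2, 3] and [4, 5, 6, 7]

rank0, rank1 = split_into_partitions(full_weight, num_partitions=2, partition_dim=0, stride=2)
print(rank0.squeeze().tolist(), rank1.squeeze().tolist())   # [0, 1, 4, 5] and [2, 3, 6, 7]

In DistributedGPT3.__init__ further down, ckpt_rank = model_parallel_rank * ckpt_ws // tensor_ws picks which saved shard each runtime rank loads before split_state_dict keeps only its slice, so a checkpoint written with a smaller model-parallel size can still be finetuned at a larger one.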


class DistributedGPT3(TorchModel):

def __init__(self,
@@ -942,33 +1003,63 @@ class DistributedGPT3(TorchModel):
model = Float16Module(model, self.config)

self.dist_model = model
load_model = pre_load(mpu, model_dir, tag=path_load_tag)

tensor_ws = mpu.get_model_parallel_world_size()
ckpt_ws = kwargs.pop('checkpoint_model_parallel_size', tensor_ws)
ckpt_rank = mpu.get_model_parallel_rank() * ckpt_ws // tensor_ws
load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag)
load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws)

self.dist_model.load_state_dict(load_model)

self.inference_params = None

def forward_step(self, tokens, attention_mask, position_ids):
logits = self.dist_model(
def train(self, mode: bool = True):
if mode:
self.inference_params = None
return super().train(mode)

def forward(self,
tokens,
attention_mask=None,
position_ids=None,
labels=None,
prompt_length=None):
outputs = self.dist_model(
tokens,
attention_mask,
position_ids,
inference_params=self.inference_params)
self.inference_params.sequence_len_offset += tokens.size(1)
return logits
inference_params=self.inference_params,
labels=labels)
if labels is None:
self.inference_params.sequence_len_offset += tokens.size(1)
return TextGenerationModelOutput(logits=outputs)
else:
loss_mask = torch.ones(
tokens.size(), dtype=torch.float, device=tokens.device)

losses = outputs.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

return TextGenerationModelOutput(loss=loss)

def generate(self,
tokens,
temperature=1.0,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False):
lengths = torch.tensor([tokens.size(1)], device=tokens.device)
stop_on_eol=False,
**kwargs):
batch_size = tokens.size(0)
lengths = kwargs.pop(
'prompt_length',
torch.tensor([tokens.size(1)], device=tokens.device))
pads = torch.ones(
1, self.config.tokens_to_generate,
batch_size, self.config.tokens_to_generate,
device=tokens.device).long() * self.config.eod_id
tokens = torch.cat((tokens, pads), dim=-1)

batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length,
@@ -1009,8 +1100,8 @@ class DistributedGPT3(TorchModel):
..., prev_context_length:context_length, :context_length]

# logits will be meanigful only in the last pipeline stage.
logits = self.forward_step(tokens2use, attention_mask2use,
positions2use)
logits = self(tokens2use, attention_mask2use,
positions2use).logits

# Sample.
last_token_logits = logits[:, -1, :]
@@ -1054,4 +1145,16 @@ class DistributedGPT3(TorchModel):
break

tokens = tokens[:, :(context_length + 1)]
return tokens
return TokenGeneratorOutput(sequences=tokens)

def state_dict(self):
return self.dist_model.state_dict()

def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = save_checkpoint,
config: Optional[dict] = None,
**kwargs):
return super().save_pretrained(target_folder, save_checkpoint_names,
save_function, config, **kwargs)
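The new labels branch in GPT3Model.forward returns a per-token loss from mpu.vocab_parallel_cross_entropy, and DistributedGPT3.forward averages it under an all-ones loss mask. On a single device with gathered logits the same computation reduces to ordinary token-level cross entropy; a sketch with invented shapes:

import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 16                     # made-up sizes
logits = torch.randn(batch, seq, vocab)          # stands in for the gathered [b, s, v] logits
labels = torch.randint(0, vocab, (batch, seq))   # already shifted by one in the preprocessor

losses = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction='none')
loss_mask = torch.ones(batch * seq)              # the uniform mask built in forward()
loss = torch.sum(losses * loss_mask) / loss_mask.sum()
print(loss)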

modelscope/models/nlp/gpt3/text_generation.py (+16, -6)

@@ -1,10 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Dict

from transformers import BertTokenizer

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.models.nlp.gpt3 import GPT3Model
from modelscope.utils.constant import Tasks

__all__ = ['GPT3ForTextGeneration']
@@ -21,11 +24,15 @@ class GPT3ForTextGeneration(TorchModel):
"""
super().__init__(model_dir, *args, **kwargs)

from modelscope.models.nlp.gpt3 import GPT3Model
from transformers import BertTokenizer

self.model = GPT3Model.from_pretrained(model_dir)
self.tokenizer = BertTokenizer.from_pretrained(model_dir)
# Temporarily compatible with DistributedGPT3 and GPT3Model,
# the base/large model based on GPT3Model will be replaced in the future,
# and GPT3Model will be deprecated
if 'model_parallel_size' in kwargs:
from modelscope.models.nlp import DistributedGPT3
self.model = DistributedGPT3(model_dir, **kwargs)
else:
self.model = GPT3Model.from_pretrained(model_dir)
self.tokenizer = BertTokenizer.from_pretrained(model_dir)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model
@@ -43,6 +50,9 @@ class GPT3ForTextGeneration(TorchModel):
return self.model(**input)

def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
if not isinstance(self.model, GPT3Model):
return self.model.generate(**input)

assert 'input_ids' in input, "generate function must accept 'input_ids' key"
input_ids = input['input_ids']
if 'attention_mask' in input:


modelscope/models/nlp/gpt3/tokenizer.py (+7, -3)

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from tokenizers import Tokenizer


@@ -25,9 +27,11 @@ class JiebaBPETokenizer:
self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
try:
import jieba
import logging
jieba.setLogLevel(logging.INFO)
except ImportError:
raise ImportError(
'You need to install rjieba to use JiebaTokenizer. '
'You need to install jieba to use JiebaTokenizer. '
'See https://pypi.org/project/rjieba/ for installation.')
self.jieba = jieba
self.new_line = self.vocab['\n']
@@ -49,7 +53,7 @@ class JiebaBPETokenizer:
inv_vocab[val] = key
return inv_vocab

def tokenize(self, text, is_code=False):
def tokenize(self, text: str, is_code: bool = False) -> List[int]:
"""
"""
if not is_code:
@@ -61,7 +65,7 @@ class JiebaBPETokenizer:
text, is_pretokenized=False, add_special_tokens=True).ids

def detokenize(self, token_ids):
text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
return text

@property


modelscope/models/nlp/plug/distributed_plug.py (+2, -1)

@@ -110,7 +110,8 @@ class DistributedPlug(TorchModel):
if 'LayerNorm' in name:
_module.float()

load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
load_model = pre_load(
mpu.get_model_parallel_rank(), self.model_dir, tag=path_load_tag)
model_dict = model.module.model.state_dict()
for key in load_model:
if key not in model_dict.keys():


modelscope/pipelines/nlp/distributed_gpt3_pipeline.py (+4, -3)

@@ -5,7 +5,7 @@ from typing import Any, Dict
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3
from modelscope.models.nlp import DistributedGPT3
from modelscope.pipelines.base import DistributedPipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationJiebaPreprocessor
@@ -30,7 +30,7 @@ class DistributedGPT3Pipeline(DistributedPipeline):
Extra kwargs passed into the preprocessor's constructor.
"""
if preprocessor is None:
preprocessor = TextGenerationJiebaPreprocessor(model, **kwargs)
preprocessor = TextGenerationJiebaPreprocessor(model)
super().__init__(model, preprocessor=preprocessor, **kwargs)
assert hasattr(preprocessor, 'tokenizer')

@@ -58,5 +58,6 @@ class DistributedGPT3Pipeline(DistributedPipeline):
from modelscope.outputs import OutputKeys
return {
OutputKeys.TEXT:
self.preprocessor.tokenizer.detokenize(inputs[0].tolist())
self.preprocessor.tokenizer.detokenize(
inputs.sequences[0].tolist())
}
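For reference, a minimal inference sketch through the updated pipeline. The model id is the one used in the tests below; enough GPUs and a working megatron_util/mpu install are assumed, and the call itself is a sketch rather than something verified against this revision.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_generation, model='damo/nlp_gpt3_text-generation_1.3B')
result = pipe('好雨知时节，')
print(result['text'])   # postprocess() above returns {OutputKeys.TEXT: ...}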

modelscope/pipelines/nlp/text_generation_pipeline.py (+1, -5)

@@ -30,7 +30,6 @@ class TextGenerationPipeline(Pipeline):
device: str = 'gpu',
auto_collate=True,
first_sequence='sentence',
sequence_length=128,
**kwargs):
"""Use `model` and `preprocessor` to create a generation pipeline for prediction.

@@ -63,10 +62,7 @@ class TextGenerationPipeline(Pipeline):

if preprocessor is None:
self.preprocessor = Preprocessor.from_pretrained(
self.model.model_dir,
first_sequence=first_sequence,
sequence_length=sequence_length,
**kwargs)
self.model.model_dir, first_sequence=first_sequence, **kwargs)
self.model.eval()
self.postprocessor = kwargs.pop('postprocessor', None)
if self.postprocessor is None and hasattr(self.model, 'model_dir'):


modelscope/preprocessors/nlp/text_generation_preprocessor.py (+28, -5)

@@ -192,7 +192,9 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
model_dir: str,
mode: str = ModeKeys.INFERENCE,
src_txt='src_txt',
tgt_txt=None):
tgt_txt=None,
sequence_length: int = 128,
use_fast=None):
from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
super().__init__(mode, src_txt, tgt_txt)
if self.tgt_txt is not None:
@@ -202,6 +204,7 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
self.src_txt = src_txt
self.tokenizer = JiebaBPETokenizer(
osp.join(model_dir, 'tokenizer.json'))
self.max_length = sequence_length

def decode(self, tokens, **kwargs):
"""Decode the tokens to real text.
@@ -214,6 +217,14 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
"""
return self.tokenizer.detokenize(tokens)

def _truncate(self, array: np.ndarray) -> np.ndarray:
if len(array) < self.max_length:
return np.pad(
array, (0, self.max_length - len(array)),
constant_values=self.tokenizer.eod)
else:
return array[:self.max_length]

def _tokenize_text(self, sequence1, sequence2=None, **kwargs):
"""Tokenize the text.

@@ -224,10 +235,22 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
Returns:
The encoded sequence.
"""
return {
'input_ids':
torch.tensor(self.tokenizer.tokenize(sequence1)).unsqueeze_(0)
}
if self.mode == ModeKeys.INFERENCE:
return {
'input_ids':
torch.tensor(self.tokenizer.tokenize(sequence1)).unsqueeze_(0)
}
else:
tokens = self.tokenizer.tokenize(sequence1)
prompt_length = min(len(tokens), self.max_length - 1)
if sequence2 is not None:
tokens += self.tokenizer.tokenize(sequence2)
tokens = self._truncate(np.array(tokens))
return {
'tokens': tokens[:-1],
'labels': tokens[1:],
'prompt_length': prompt_length,
}


@PREPROCESSORS.register_module(
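In training mode the Jieba preprocessor now emits a (tokens, labels) pair shifted by one plus the prompt length, padding or truncating to sequence_length with the eod token. A self-contained sketch with made-up token ids, so it runs without a tokenizer.json:

import numpy as np

max_length, eod_id = 8, 0            # hypothetical sequence_length and eod id

def truncate(array: np.ndarray) -> np.ndarray:
    # mirrors _truncate above: pad with eod up to max_length, or cut off
    if len(array) < max_length:
        return np.pad(array, (0, max_length - len(array)), constant_values=eod_id)
    return array[:max_length]

src_ids = [11, 12, 13]               # tokenized src_txt
tgt_ids = [21, 22]                   # tokenized tgt_txt

prompt_length = min(len(src_ids), max_length - 1)
full = truncate(np.array(src_ids + tgt_ids))
sample = {
    'tokens': full[:-1],             # model input
    'labels': full[1:],              # next-token targets
    'prompt_length': prompt_length,
}
print(sample)                        # tokens and labels both have length max_length - 1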


modelscope/trainers/hooks/logger/text_logger_hook.py (+1, -1)

@@ -74,7 +74,7 @@ class TextLoggerHook(LoggerHook):
self._dump_log(trainer.meta)

def _get_max_memory(self, trainer):
device = getattr(trainer.model, 'output_device', None)
device = torch.cuda.current_device()
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([mem / (1024 * 1024)],
dtype=torch.int,


modelscope/trainers/nlp/gpt3_trainer.py (+61, -0)

@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from collections.abc import Mapping
from typing import List

import torch
from megatron import mpu

from modelscope.metainfo import Trainers
from modelscope.models import TorchModel
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
from modelscope.utils.config import Config
from modelscope.utils.file_utils import func_receive_dict_inputs


@TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
class GPT3Trainer(NlpEpochBasedTrainer):

def rebuild_config(self, cfg: Config):
super().rebuild_config(cfg)
cfg.model.rank = int(os.environ.get('LOCAL_RANK', -1))
cfg.model.master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
cfg.model.master_port = os.environ.get('MASTER_PORT', '29500')
return cfg

def train_step(self, model: TorchModel, inputs: Mapping):
keys = list(inputs.keys())
datatype = torch.int64
inputs = mpu.broadcast_data(keys, inputs, datatype)
return super().train_step(model, inputs)

def _decode(self, tokens):
tokenizer = self.eval_preprocessor.tokenizer
return tokenizer.detokenize(tokens.tolist())

def evaluation_step(self, data):
model = self.model.module if self._dist else self.model
model.eval()

with torch.no_grad():
if isinstance(
data,
Mapping) and not func_receive_dict_inputs(model.generate):
result = model.generate(**data)
else:
result = model.generate(data)

prompt_length: List[int] = data['prompt_length']
result['preds'] = [
self._decode(seq[skip_len:])
for seq, skip_len in zip(result['sequences'], prompt_length)
]
data['tgts'] = [
self._decode(seq[skip_len - 1:])
for seq, skip_len in zip(data['labels'], prompt_length)
]
assert len(result['preds']) == len(data['tgts'])

return result
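The slicing in evaluation_step follows from the preprocessor's shift: labels is the full sequence offset by one, so the reference continuation starts at prompt_length - 1 while the generated continuation starts at prompt_length. A toy illustration with invented ids:

full = [10, 11, 12, 20, 21]              # made-up ids: prompt [10, 11, 12], reference [20, 21]
tokens, labels = full[:-1], full[1:]     # the shifted pair produced by the preprocessor
prompt_length = 3

generated = [10, 11, 12, 23, 24]         # hypothetical generate() output (prompt echoed back)
pred_ids = generated[prompt_length:]     # [23, 24] -> detokenized into result['preds']
tgt_ids = labels[prompt_length - 1:]     # [20, 21] -> detokenized into data['tgts']
print(pred_ids, tgt_ids)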

modelscope/trainers/trainer.py (+1, -1)

@@ -236,7 +236,7 @@ class EpochBasedTrainer(BaseTrainer):
device_name: The final device name.
"""
device_name = device if device is not None else 'gpu'
if self._dist:
if dist.is_initialized():
local_rank = get_local_rank()
device_name = f'cuda:{local_rank}'



modelscope/trainers/utils/inference.py (+2, -1)

@@ -137,7 +137,8 @@ def multi_gpu_test(trainer,
else:
batch_size = len(data)
if i >= (data_len // world_size) - 1:
total_samples = torch.LongTensor([batch_size]).to(model.device)
total_samples = torch.LongTensor([batch_size
]).to(trainer.model.device)
dist.all_reduce(total_samples, op=dist.reduce_op.SUM)
total_samples = total_samples.item()
else:


modelscope/utils/nlp/distributed.py (+2, -1)

@@ -14,6 +14,7 @@
# limitations under the License.

import math
import os

import torch
import torch.distributed as dist
@@ -36,7 +37,7 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size,
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend='nccl',
world_size=world_size,
world_size=int(os.getenv('WORLD_SIZE', world_size)),
rank=rank,
init_method=init_method)
# Set the model-parallel communicators.


modelscope/utils/nlp/load_checkpoint.py (+3, -4)

@@ -55,16 +55,15 @@ def load_checkpoint(model,
return load_path, client_states


def _get_ckpt_name(mpu, checkpoints_path, tag):
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
def _get_ckpt_name(mp_rank, checkpoints_path, tag):
ckpt_name = os.path.join(
checkpoints_path, str(tag),
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
return ckpt_name


def pre_load(mpu, load_dir, tag=''):
load_path = _get_ckpt_name(mpu, load_dir, tag)
def pre_load(mp_rank, load_dir, tag=''):
load_path = _get_ckpt_name(mp_rank, load_dir, tag)
checkpoint = torch.load(
load_path, map_location=lambda storage, loc: storage)
return checkpoint['module']


modelscope/utils/torch_utils.py (+13, -10)

@@ -107,8 +107,14 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:

def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
try:
from megatron import mpu
assert mpu.model_parallel_is_initialized()
rank = mpu.get_data_parallel_rank()
world_size = mpu.get_data_parallel_world_size()
except (ImportError, AssertionError):
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
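The get_dist_info change above is what keeps tensor and data parallelism compatible on the trainer side: the sampler now shards data per data-parallel group instead of per GPU. Under the usual Megatron-LM group layout (consecutive ranks form one model-parallel group; that layout is assumed here, the diff does not spell it out), the mapping for a 4-GPU run with 2-way tensor parallelism looks like:

world_size, model_parallel_size = 4, 2     # hypothetical launch: 4 GPUs, 2-way tensor parallel

for rank in range(world_size):
    mp_rank = rank % model_parallel_size   # which slice of the weights this rank holds
    dp_rank = rank // model_parallel_size  # which shard of the data this rank consumes
    print(f'global rank {rank}: model-parallel rank {mp_rank}, data-parallel rank {dp_rank}')

# Ranks 0 and 1 hold the two halves of one model replica, so they must see the same batches;
# returning the data-parallel rank and world size from get_dist_info() gives both the same
# sampler shard, and GPT3Trainer.train_step broadcasts the batch inside the group.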
@@ -120,16 +126,14 @@ def get_local_rank():


def is_master():
rank, _ = get_dist_info()
return rank == 0
return dist.get_rank() == 0 if dist.is_initialized() else True


def master_only(func: Callable) -> Callable:

@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
if is_master():
return func(*args, **kwargs)

return wrapper
@@ -138,12 +142,11 @@ def master_only(func: Callable) -> Callable:
def make_tmp_dir():
"""Make sure each rank has the same temporary directory on the distributed mode.
"""
rank, world_size = get_dist_info()
if world_size <= 1:
if not dist.is_initialized():
return tempfile.mkdtemp()

tmpdir = None
if rank == 0:
if is_master():
tmpdir = tempfile.mkdtemp()

dist.barrier()
@@ -162,7 +165,7 @@ def broadcast(inputs, src):
Returns:
Each rank returns the same value as src.
"""
rank, _ = get_dist_info()
rank = dist.get_rank()
shape_tensor = torch.tensor([0], device='cuda')

if rank == src:


tests/trainers/test_finetune_gpt3.py (+129, -0)

@@ -0,0 +1,129 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer


class TestFinetuneTextGeneration(unittest.TestCase):

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

@unittest.skip
def test_finetune_poetry(self):
dataset_dict = MsDataset.load('chinese-poetry-collection')
train_dataset = dataset_dict['train'].to_hf_dataset().rename_columns(
{'text1': 'src_txt'})
eval_dataset = dataset_dict['test'].to_hf_dataset().rename_columns(
{'text1': 'src_txt'})
max_epochs = 10
tmp_dir = './gpt3_poetry'

num_warmup_steps = 100

def noam_lambda(current_step: int):
current_step += 1
return min(current_step**(-0.5),
current_step * num_warmup_steps**(-1.5))

def cfg_modify_fn(cfg):
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': noam_lambda,
'options': {
'by_epoch': False
}
}
cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
cfg.train.dataloader = {
'batch_size_per_gpu': 16,
'workers_per_gpu': 1
}
return cfg

kwargs = dict(
model='damo/nlp_gpt3_text-generation_1.3B',
train_dataset=train_dataset,
eval_dataset=eval_dataset,
max_epochs=max_epochs,
work_dir=tmp_dir,
cfg_modify_fn=cfg_modify_fn)

# Construct trainer and train
trainer = build_trainer(
name=Trainers.gpt3_trainer, default_args=kwargs)
trainer.train()

@unittest.skip
def test_finetune_dureader(self):
# DuReader_robust-QG is an example data set,
# users can also use their own data set for training
dataset_dict = MsDataset.load('DuReader_robust-QG')

train_dataset = dataset_dict['train'].to_hf_dataset() \
.rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
.map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})
eval_dataset = dataset_dict['validation'].to_hf_dataset() \
.rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
.map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})

max_epochs = 10
tmp_dir = './gpt3_dureader'

num_warmup_steps = 200

def noam_lambda(current_step: int):
current_step += 1
return min(current_step**(-0.5),
current_step * num_warmup_steps**(-1.5))

def cfg_modify_fn(cfg):
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': noam_lambda,
'options': {
'by_epoch': False
}
}
cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
cfg.train.dataloader = {
'batch_size_per_gpu': 16,
'workers_per_gpu': 1
}
cfg.train.hooks.append({
'type': 'EvaluationHook',
'by_epoch': True,
'interval': 1
})
cfg.preprocessor.sequence_length = 512
cfg.model.checkpoint_model_parallel_size = 1
return cfg

kwargs = dict(
model='damo/nlp_gpt3_text-generation_1.3B',
train_dataset=train_dataset,
eval_dataset=eval_dataset,
max_epochs=max_epochs,
work_dir=tmp_dir,
cfg_modify_fn=cfg_modify_fn)

# Construct trainer and train
trainer = build_trainer(
name=Trainers.gpt3_trainer, default_args=kwargs)
trainer.train()


if __name__ == '__main__':
unittest.main()
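Both tests are skipped by default since they need several GPUs, megatron_util/mpu and the 1.3B checkpoint. One possible way to run the poetry finetuning outside unittest is to lift the test body into a small script and start one process per GPU; the script and launch command below are an assumption, not part of this change, and the noam_lambda scheduler plus cfg_modify_fn from the test are omitted for brevity.

# finetune_gpt3_poetry.py -- hypothetical standalone version of test_finetune_poetry.
# Assumed launch: torchrun --nproc_per_node=2 finetune_gpt3_poetry.py
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

dataset_dict = MsDataset.load('chinese-poetry-collection')
train_dataset = dataset_dict['train'].to_hf_dataset().rename_columns({'text1': 'src_txt'})
eval_dataset = dataset_dict['test'].to_hf_dataset().rename_columns({'text1': 'src_txt'})

kwargs = dict(
    model='damo/nlp_gpt3_text-generation_1.3B',
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_epochs=10,
    work_dir='./gpt3_poetry')

trainer = build_trainer(name=Trainers.gpt3_trainer, default_args=kwargs)
trainer.train()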

tests/trainers/test_finetune_text_generation.py (+2, -1)

@@ -80,7 +80,8 @@ class TestFinetuneTextGeneration(unittest.TestCase):
max_epochs=self.max_epochs,
work_dir=self.tmp_dir)

trainer = build_trainer(default_args=kwargs)
trainer = build_trainer(
name=Trainers.text_generation_trainer, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)

