
[to #42322933] plug finetune

plug finetune: regression-tested on the DuReader-robust dataset, reproducing the best result
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10916382
master^2
suluyan.sly, yingda.chen · 2 years ago
commit 1394019102
14 changed files with 837 additions and 119 deletions
  1. modelscope/metainfo.py (+4, -0)
  2. modelscope/models/nlp/plug/AnnealingLR.py (+88, -0)
  3. modelscope/models/nlp/plug/backbone.py (+112, -0)
  4. modelscope/models/nlp/plug/configuration.py (+1, -1)
  5. modelscope/models/nlp/plug/distributed_plug.py (+24, -111)
  6. modelscope/models/nlp/plug/generator.py (+225, -0)
  7. modelscope/preprocessors/nlp/text_generation_preprocessor.py (+7, -1)
  8. modelscope/trainers/hooks/__init__.py (+1, -1)
  9. modelscope/trainers/hooks/checkpoint_hook.py (+4, -2)
  10. modelscope/trainers/hooks/deepspeed_hook.py (+116, -0)
  11. modelscope/trainers/hooks/logger/text_logger_hook.py (+2, -1)
  12. modelscope/trainers/nlp/plug_trainer.py (+195, -0)
  13. modelscope/trainers/trainer.py (+5, -2)
  14. tests/trainers/test_plug_finetune_text_generation.py (+53, -0)

modelscope/metainfo.py (+4, -0)

@@ -338,6 +338,7 @@ class Trainers(object):
nlp_veco_trainer = 'nlp-veco-trainer'
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
text_generation_trainer = 'text-generation-trainer'
nlp_plug_trainer = 'nlp-plug-trainer'

# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
@@ -500,6 +501,9 @@ class Hooks(object):
# CLIP logit_scale clamp
ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'

# train
DeepspeedHook = 'DeepspeedHook'


class LR_Schedulers(object):
"""learning rate scheduler is defined here


modelscope/models/nlp/plug/AnnealingLR.py (+88, -0)

@@ -0,0 +1,88 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""

import math

import torch
from torch.optim.lr_scheduler import _LRScheduler


class AnnealingLR(_LRScheduler):
"""Anneals the learning rate from start to zero along a cosine curve."""

DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']

def __init__(self,
optimizer,
start_lr,
warmup_iter,
num_iters,
decay_style=None,
last_iter=-1):
self.optimizer = optimizer
self.start_lr = start_lr
self.warmup_iter = warmup_iter
self._step_count = last_iter + 1
self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style,
str) else None
self.step(self._step_count)
if torch.distributed.get_rank() == 0:
print('learning rate decaying', decay_style)

def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4
if self.warmup_iter > 0 and self._step_count <= self.warmup_iter:
return float(self.start_lr) * self._step_count / self.warmup_iter
else:
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr * ((
self.end_iter - # noqa W504
(self._step_count - self.warmup_iter)) / self.end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
return self.start_lr / 2.0 * (
math.cos(math.pi * (self._step_count - self.warmup_iter)
/ self.end_iter) + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
# TODO: implement exponential decay
return self.start_lr
else:
return self.start_lr

def step(self, step_num=None):
if step_num is None:
step_num = self._step_count + 1
self._step_count = step_num
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr

def state_dict(self):
sd = {
'start_lr': self.start_lr,
'warmup_iter': self.warmup_iter,
'_step_count': self._step_count,
'decay_style': self.decay_style,
'end_iter': self.end_iter
}
return sd

def load_state_dict(self, sd):
self.start_lr = sd['start_lr']
self.warmup_iter = sd['warmup_iter']
self._step_count = sd['_step_count']
self.end_iter = sd['end_iter']
self.decay_style = sd['decay_style']
self.step(self._step_count)
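
For orientation, a minimal usage sketch of the AnnealingLR scheduler added above (not part of this change). It assumes a throwaway single-process gloo group, since the constructor calls torch.distributed.get_rank(); the optimizer, iteration count, and warmup fraction are illustrative only.

import torch
import torch.distributed as dist

from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR

# A process group must exist because AnnealingLR queries the rank on init.
dist.init_process_group(
    backend='gloo',
    init_method='tcp://127.0.0.1:29500',
    rank=0,
    world_size=1)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

total_iters = 1000  # illustrative
scheduler = AnnealingLR(
    optimizer,
    start_lr=1e-4,
    warmup_iter=int(0.1 * total_iters),  # linear warmup over 10% of steps
    num_iters=total_iters,
    decay_style='cosine')

for _ in range(total_iters):
    optimizer.step()
    scheduler.step()  # writes the new lr into every param group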

modelscope/models/nlp/plug/backbone.py (+112, -0)

@@ -1009,6 +1009,118 @@ class PlugModel(torch.nn.Module):
sequence_output=sequence_output,
parallel_output=parallel_output)

@staticmethod
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-
# conversational-ai-with-transfer-learning-2d818ac26313

if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
None]
logits[indices_to_remove] = filter_value

if top_p > 0.0:
# convert to 1D
logits = logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
# going back to 2D
logits = logits.view(1, -1).contiguous()
return logits

def generate(self, input, out_length=128, model_cfg=None, **kwargs):
device = torch.cuda.current_device()
batch_size = input['input_ids'].shape[0]
tokens = input['input_ids'].view(1, -1).contiguous().to(device)
dec_input_ids = input['dec_input_ids'].to(device)
attention_mask = input['attention_mask'].to(device)
self.model.eval()
with torch.no_grad():
# Only supports batch_size=1
all_generate_tokens = []
generate_tokens = []
counter = 0
sequence_output = None
vocab_size = self.config.original_vocab_size
sep_token_idx = 102 # index of [SEP] token in BertTokenizer
while counter < out_length:
if counter % 128 == 0 and counter != 0:
# Sliding window
generate_tokens.append(sep_token_idx)
start = (tokens == sep_token_idx).nonzero(
as_tuple=True)[-1]
if start + len(generate_tokens) >= 512:
tokens = torch.cat([
tokens[:start],
torch.cuda.LongTensor(generate_tokens)
], -1)[-512:]
else:
tokens[0][start:start + len(generate_tokens
)] = torch.cuda.LongTensor(
generate_tokens)

attention_mask = (tokens != 0)
dec_input_ids = input['dec_input_ids'].to(device)
generate_tokens = []
sequence_output = None

position_ids = torch.full([batch_size, 1],
len(generate_tokens),
dtype=torch.long,
device=device)
_, logits, sequence_output = self.model(
tokens,
None,
attention_mask,
dec_input_ids,
attention_mask,
position_ids,
is_infer=True,
sequence_output=sequence_output,
parallel_output=False)
logits = logits[:, -1, :]
logits = logits / model_cfg['temperature']
logits = self.top_k_logits(
logits, top_k=model_cfg['top_k'], top_p=model_cfg['top_p'])
log_probs = F.softmax(logits, dim=-1)
prev = torch.argmax(log_probs, 1).unsqueeze(1)
# prev = torch.multinomial(log_probs, num_samples=1)
prev_token = prev[0].item()
if prev_token >= vocab_size:
prev_token = 100
prev[0] = 100
if prev_token == 102 and len(all_generate_tokens) > int(
max(1, out_length) * 0.8):
break
if prev_token == 102:
counter += 1
continue
dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
generate_tokens.append(prev_token)
all_generate_tokens.append(prev_token)
counter += 1

generate_context = []
for token in all_generate_tokens:
if generate_context and generate_context[
-1] == 100 and token == 100:
continue
else:
generate_context.append(token)
return {'generate_context': generate_context}

def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.model.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
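
A quick standalone check of the top-k/top-p filter introduced above (illustrative values; importing PlugModel assumes the PLUG dependencies such as megatron and deepspeed are installed):

import torch
import torch.nn.functional as F

from modelscope.models.nlp.plug.backbone import PlugModel

# Toy 5-token distribution; keep only the two highest-scoring tokens.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = PlugModel.top_k_logits(logits.clone(), top_k=2)
print(F.softmax(filtered, dim=-1))  # probability mass only on the first two tokens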


modelscope/models/nlp/plug/configuration.py (+1, -1)

@@ -225,7 +225,7 @@ class PlugNLGConfig(PlugNLUConfig):
fp32_layernorm=True,
fp32_embedding=False,
fp32_tokentypes=False,
layernorm_epsilon=1e-5,
layernorm_epsilon=1e-12,
attn_separate=False,
**kwargs):
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)


modelscope/models/nlp/plug/distributed_plug.py (+24, -111)

@@ -75,7 +75,7 @@ class DistributedPlug(TorchModel):
seed = 42 if 'seed' not in kwargs else kwargs['seed']
set_random_seed_mpu(seed)
self.iteration = 0
self.dist_model = self.initialize_model(path_load_tag='model')
self.model = self.initialize_model(path_load_tag='model')

def initialize_model(self, path_load_tag='model'):
"""Build the model."""
@@ -120,115 +120,28 @@ class DistributedPlug(TorchModel):
model.module.model.load_state_dict(load_model, strict=False)
return model

@staticmethod
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-
# conversational-ai-with-transfer-learning-2d818ac26313

if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
None]
logits[indices_to_remove] = filter_value

if top_p > 0.0:
# convert to 1D
logits = logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
# going back to 2D
logits = logits.view(1, -1).contiguous()
return logits
def forward(self,
input_tokens,
token_type_ids=None,
attention_mask=None,
target_tokens=None,
position_ids=None,
decode_attention_mask=None,
checkpoint_activations=False,
is_infer=False,
sequence_output=None,
parallel_output=True):
return self.model(
input_tokens,
token_type_ids,
attention_mask,
target_tokens,
position_ids,
decode_attention_mask,
checkpoint_activations=checkpoint_activations,
is_infer=is_infer,
sequence_output=sequence_output,
parallel_output=parallel_output)

def generate(self, input: Dict[str, Tensor], out_length=128, **kwargs):
device = torch.cuda.current_device()
batch_size = input['input_ids'].shape[0]
tokens = input['input_ids'].view(1, -1).contiguous().to(device)
dec_input_ids = input['dec_input_ids'].to(device)
attention_mask = input['attention_mask'].to(device)
self.dist_model.eval()
with torch.no_grad():
# Only supports batch_size=1
all_generate_tokens = []
generate_tokens = []
counter = 0
sequence_output = None
vocab_size = self.config.original_vocab_size
sep_token_idx = 102 # index of [SEP] token in BertTokenizer
while counter < out_length:
if counter % 128 == 0 and counter != 0:
# Sliding window
generate_tokens.append(sep_token_idx)
start = (tokens == sep_token_idx).nonzero(
as_tuple=True)[-1]
if start + len(generate_tokens) >= 512:
tokens = torch.cat([
tokens[:start],
torch.cuda.LongTensor(generate_tokens)
], -1)[-512:]
else:
tokens[0][start:start + len(generate_tokens
)] = torch.cuda.LongTensor(
generate_tokens)

attention_mask = (tokens != 0)
dec_input_ids = input['dec_input_ids'].to(device)
generate_tokens = []
sequence_output = None

position_ids = torch.full([batch_size, 1],
len(generate_tokens),
dtype=torch.long,
device=device)
_, logits, sequence_output = self.dist_model(
tokens,
None,
attention_mask,
dec_input_ids,
attention_mask,
position_ids,
is_infer=True,
sequence_output=sequence_output,
parallel_output=False)
logits = logits[:, -1, :]
logits = logits / self.model_cfg['temperature']
logits = self.top_k_logits(
logits,
top_k=self.model_cfg['top_k'],
top_p=self.model_cfg['top_p'])
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
prev_token = prev[0].item()
if prev_token >= vocab_size:
prev_token = 100
prev[0] = 100
if prev_token == 102 and len(all_generate_tokens) > int(
max(1, out_length) * 0.8):
break
if prev_token == 102:
counter += 1
continue
dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
generate_tokens.append(prev_token)
all_generate_tokens.append(prev_token)
counter += 1

generate_context = []
for token in all_generate_tokens:
if generate_context and generate_context[
-1] == 100 and token == 100:
continue
else:
generate_context.append(token)
return {'generate_context': generate_context}
return self.model.generate(input, out_length, self.model_cfg, **kwargs)

modelscope/models/nlp/plug/generator.py (+225, -0)

@@ -0,0 +1,225 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch


class TextGenerator(object):

def __init__(self,
model,
vocab,
symbols,
global_scorer=None,
logger=None,
dump_beam=''):
self.alpha = 0.6

self.logger = logger
self.cuda = (torch.cuda.device_count() > 0)

self.model = model
# TODO generator
self.vocab = vocab
self.symbols = symbols
self.start_token = 101 # '[CLS]' in BertTokenizer
self.end_token = 102 # '[SEP]' in BertTokenizer

self.global_scorer = global_scorer
self.beam_size = 5
self.min_length = 5
self.max_length = 384

self.dump_beam = dump_beam

# for debugging
self.beam_trace = self.dump_beam != ''
self.beam_accum = None

if self.beam_trace:
self.beam_accum = {
'predicted_ids': [],
'beam_parent_ids': [],
'scores': [],
'log_probs': []
}

def _build_target_tokens(self, pred):
tokens = []
for tok in pred:
tok = int(tok)
tokens.append(tok)
if tokens[-1] == self.end_token:
tokens = tokens[:-1]
break
tokens = [t for t in tokens if t < len(self.vocab)]
tokens = self.vocab.DecodeIds(tokens).split(' ')
return tokens

def tile(self, x, count, dim=0):
"""
Tiles x on dimension dim count times.
"""
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = x.view(batch, -1) \
.transpose(0, 1) \
.repeat(count, 1) \
.transpose(0, 1) \
.contiguous() \
.view(*out_size)
if dim != 0:
x = x.permute(perm).contiguous()
return x

def translate_batch(self, encoder_inputs, fast=False):
with torch.no_grad():
return self._fast_translate_batch(
encoder_inputs, self.max_length, min_length=self.min_length)

def _fast_translate_batch(self, encoder_inputs, max_length, min_length=0):

assert not self.dump_beam

beam_size = self.beam_size
tokens, types, padding_mask = encoder_inputs
batch_size = tokens.size(0)
device = tokens.device
tmp_alive_seq = torch.full([batch_size, 1],
self.start_token,
dtype=torch.long,
device=device)
prediction_scores, dec_feat_seq, sequence_output = self.model(
tokens,
types,
padding_mask,
tmp_alive_seq,
None,
None,
checkpoint_activations=False,
is_infer=True,
parallel_output=False,
sequence_output=None)
src_features = sequence_output

src_features = self.tile(src_features, beam_size, dim=0)
attention_mask = self.tile(padding_mask, beam_size, dim=0)
batch_offset = torch.arange(
batch_size, dtype=torch.long, device=device)
beam_offset = torch.arange(
0,
batch_size * beam_size,
step=beam_size,
dtype=torch.long,
device=device)
alive_seq = torch.full([batch_size * beam_size, 1],
self.start_token,
dtype=torch.long,
device=device)

# Give full probability to the first beam on the first step.
topk_log_probs = (
torch.tensor(
[0.0] + [float('-inf')] * (beam_size - 1),
device=device).repeat(batch_size))

# Structure that holds finished hypotheses.
hypotheses = [[] for _ in range(batch_size)] # noqa: F812

results = {}
results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812
results['scores'] = [[] for _ in range(batch_size)] # noqa: F812
results['gold_score'] = [0] * batch_size
results['batch'] = []
dec_attn_mask = None
dec_position_ids = None

for step in range(max_length):
prediction_scores, dec_feat_seq, _ = self.model(
tokens,
types,
attention_mask,
alive_seq,
dec_position_ids,
dec_attn_mask,
checkpoint_activations=False,
is_infer=True,
parallel_output=False,
sequence_output=src_features)

dec_feat_seq = dec_feat_seq[:, -1, :]
vocab_size = dec_feat_seq.size(-1)
log_probs = torch.log(
torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1))

if step < min_length:
log_probs[:, self.end_token] = -1e20
log_probs += topk_log_probs.view(-1).unsqueeze(1)

alpha = self.alpha # global_scorer.alpha
length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
curr_scores = log_probs / length_penalty

curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
topk_log_probs = topk_scores * length_penalty

# Resolve beam origin and true word ids.
topk_beam_index = topk_ids.div(vocab_size, rounding_mode='trunc')
topk_ids = topk_ids.fmod(vocab_size)

# Map beam_index to batch_index in the flat representation.
batch_index = (
topk_beam_index
+ beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
select_indices = batch_index.view(-1)

# Append last prediction.
alive_seq = torch.cat([
alive_seq.index_select(0, select_indices),
topk_ids.view(-1, 1)
], -1)

is_finished = topk_ids.eq(self.end_token)
if step + 1 == max_length:
is_finished.fill_(1) # self.end_token)
# End condition is top beam is finished.
end_condition = is_finished[:, 0].eq(1) # self.end_token)
# Save finished hypotheses.
if is_finished.any():
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
for i in range(is_finished.size(0)):
b = batch_offset[i]
if end_condition[i]:
is_finished[i].fill_(1) # self.end_token)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
for j in finished_hyp:
hypotheses[b].append(
(topk_scores[i, j], predictions[i, j, 1:]))
# If the batch reached the end, save the n_best hypotheses.
if end_condition[i]:
best_hyp = sorted(
hypotheses[b], key=lambda x: x[0], reverse=True)
score, pred = best_hyp[0]
results['scores'][b].append(score)
results['predictions'][b].append(pred)
non_finished = end_condition.eq(0).nonzero().view(-1)
# If all sentences are translated, no need to go further.
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probs = topk_log_probs.index_select(0, non_finished)
batch_index = batch_index.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
alive_seq = predictions.index_select(0, non_finished) \
.view(-1, alive_seq.size(-1))
# Reorder states.
select_indices = batch_index.view(-1)
src_features = src_features.index_select(0, select_indices)
attention_mask = attention_mask.index_select(0, select_indices)

return results
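
For reference, the beam scores above are rescaled with a GNMT-style length penalty before the top-k selection; a tiny standalone illustration (alpha taken from the class default):

alpha = 0.6  # TextGenerator.alpha

def length_penalty(step, alpha=alpha):
    # Same formula as in _fast_translate_batch: ((5 + length) / 6) ** alpha
    return ((5.0 + (step + 1)) / 6.0) ** alpha

for step in (0, 9, 49):
    print(step + 1, round(length_penalty(step), 3))
# The penalty grows with hypothesis length; dividing the (negative) log-probs
# by it keeps longer candidates competitive with short ones.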

modelscope/preprocessors/nlp/text_generation_preprocessor.py (+7, -1)

@@ -122,6 +122,8 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
False)
kwargs['max_length'] = sequence_length
self.src_length = kwargs['max_length']
self.tgt_length = kwargs.pop('target_max_length', kwargs['max_length'])
model_type = None
if model_dir is not None:
model_type = get_model_type(model_dir)
@@ -154,10 +156,14 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
'return_tensors'] = 'pt' if self.mode == ModeKeys.INFERENCE else None

output = self.nlp_tokenizer(sequence1, **kwargs)

if self.mode != ModeKeys.INFERENCE:
if sequence2 is not None:
self.nlp_tokenizer.tokenize_kwargs[
'max_length'] = self.tgt_length
labels = self.nlp_tokenizer(sequence2)['input_ids']
self.nlp_tokenizer.tokenize_kwargs[
'max_length'] = self.src_length

src_input_ids = output['input_ids']
src_attention_mask = output['attention_mask']
else:


modelscope/trainers/hooks/__init__.py (+1, -1)

@@ -25,7 +25,7 @@ else:
'hook': ['Hook'],
'iter_timer_hook': ['IterTimerHook'],
'logger': ['TensorboardHook', 'TextLoggerHook'],
'lr_scheduler_hook': ['LrSchedulerHook'],
'lr_scheduler_hook': ['LrSchedulerHook', 'NoneLrSchedulerHook'],
'optimizer_hook': [
'ApexAMPOptimizerHook', 'NoneOptimizerHook', 'OptimizerHook',
'TorchAMPOptimizerHook'


modelscope/trainers/hooks/checkpoint_hook.py (+4, -2)

@@ -104,7 +104,8 @@ class CheckpointHook(Hook):
return

if self._should_save(trainer):
if is_master():
if is_master() or trainer.cfg.model.get('model_parallel_size',
1) != 1:
self.logger.info(
f'Saving checkpoint at {trainer.epoch + 1} epoch')
self._save_checkpoint(trainer)
@@ -260,7 +261,8 @@ class CheckpointHook(Hook):
return

if self._should_save(trainer):
if is_master():
if is_master() or trainer.cfg.model.get('model_parallel_size',
1) != 1:
self.logger.info(
f'Saving checkpoint at {trainer.iter + 1} iterations')
self._save_checkpoint(trainer)


modelscope/trainers/hooks/deepspeed_hook.py (+116, -0)

@@ -0,0 +1,116 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from types import MethodType

import deepspeed
from megatron import mpu

from modelscope.metainfo import Hooks
from modelscope.trainers.hooks import (BestCkptSaverHook, CheckpointHook,
LrSchedulerHook, NoneLrSchedulerHook,
NoneOptimizerHook, OptimizerHook)
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.torch_utils import is_master
from .builder import HOOKS
from .hook import Hook
from .priority import Priority


@HOOKS.register_module(module_name=Hooks.DeepspeedHook)
class DeepspeedHook(Hook):
PRIORITY = Priority.VERY_HIGH

def __init__(self,
deepspeed_activation_checkpointing=True,
save_zero_checkpoint=False,
loss_key='loss'):
self.save_zero_checkpoint = save_zero_checkpoint
self.loss_key = loss_key
self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing

def before_run(self, trainer):
# deepspeed init
args = trainer.cfg.train
args.deepspeed_config = os.path.join(trainer.model_dir,
args.deepspeed_config)

trainer.model, _, _, _ = deepspeed.initialize(
model=trainer.model,
optimizer=trainer.optimizer,
args=args,
lr_scheduler=trainer.lr_scheduler,
mpu=mpu,
dist_init_required=False)
trainer.model.save_zero_checkpoint = self.save_zero_checkpoint

if self.deepspeed_activation_checkpointing:
model = trainer.model
while hasattr(model, 'module'):
model = model.module
deepspeed.checkpointing.configure(
mpu,
deepspeed_config=args.deepspeed_config,
num_checkpoints=model.config.num_hidden_layers)

mpu.checkpoint = deepspeed.checkpointing.checkpoint
mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed

# modify hooks
for i, hook in enumerate(trainer._hooks):
# backward & step
if isinstance(hook, OptimizerHook):
trainer._hooks[i] = NoneOptimizerHook()
if isinstance(hook, LrSchedulerHook):
trainer._hooks[i] = NoneLrSchedulerHook()

# save checkpoint
if isinstance(hook, CheckpointHook):

def _save_checkpoint(self, trainer):
if self.by_epoch:
cur_save_dir = os.path.join(
self.save_dir,
f'{LogKeys.EPOCH}_{trainer.epoch + 1}')
else:
cur_save_dir = os.path.join(
self.save_dir,
f'{LogKeys.ITER}_{trainer.iter + 1}')
if (self.is_last_epoch(trainer)
and self.by_epoch) or (self.is_last_iter(trainer)
and not self.by_epoch):
cur_save_dir = os.path.join(self.save_dir,
ModelFile.TRAIN_OUTPUT_DIR)
trainer.model.save_checkpoint(cur_save_dir)

trainer._hooks[i]._save_checkpoint = MethodType(
_save_checkpoint, trainer._hooks[i])

if isinstance(hook, BestCkptSaverHook):

def _save_checkpoint(self, trainer):
if self.by_epoch:
cur_save_dir = os.path.join(
self.save_dir,
f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
)
else:
cur_save_dir = os.path.join(
self.save_dir,
f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
)
trainer.model.save_checkpoint(cur_save_dir)
self._best_ckpt_file = cur_save_dir

trainer._hooks[i]._save_checkpoint = MethodType(
_save_checkpoint, trainer._hooks[i])

def after_train_iter(self, trainer):
# The `trainer.model` here is actually a deepspeed engine object.
# backward step
loss = trainer.train_outputs[self.loss_key]
trainer.model.backward(loss)

# update parameters
trainer.model.step()
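
The checkpoint handling above works by rebinding each hook's _save_checkpoint at runtime so saving goes through the DeepSpeed engine; a toy illustration of that MethodType pattern (not ModelScope code):

from types import MethodType


class Saver:

    def _save_checkpoint(self):
        return 'default torch.save path'


def _deepspeed_save_checkpoint(self):
    # The replacement is bound to this instance only; other hooks are untouched.
    return 'deepspeed engine.save_checkpoint path'


hook = Saver()
hook._save_checkpoint = MethodType(_deepspeed_save_checkpoint, hook)
assert hook._save_checkpoint() == 'deepspeed engine.save_checkpoint path'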

modelscope/trainers/hooks/logger/text_logger_hook.py (+2, -1)

@@ -80,7 +80,8 @@ class TextLoggerHook(LoggerHook):
dtype=torch.int,
device=device)
_, world_size = get_dist_info()
if world_size > 1:
if world_size > 1 and getattr(trainer.cfg.model, 'model_parallel_size',
1) < world_size:
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
return mem_mb.item()



modelscope/trainers/nlp/plug_trainer.py (+195, -0)

@@ -0,0 +1,195 @@
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from megatron import mpu
from torch import nn

from modelscope.metainfo import Trainers
from modelscope.models.base import Model, TorchModel
from modelscope.models.nlp.plug import DistributedPlug
from modelscope.models.nlp.plug.backbone import BertLayerNorm
from modelscope.models.nlp.plug.generator import TextGenerator
from modelscope.utils.constant import ModeKeys
from ..base import TRAINERS
from ..nlp_trainer import NlpEpochBasedTrainer


@TRAINERS.register_module(module_name=Trainers.nlp_plug_trainer)
class PlugTrainer(NlpEpochBasedTrainer):

def build_model(self) -> Union[nn.Module, TorchModel]:
rank = int(os.environ.get('LOCAL_RANK', -1))
master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
master_port = os.environ.get('MASTER_PORT', '29500')
model = DistributedPlug(
self.model_dir,
rank,
master_ip=master_ip,
master_port=master_port,
**self.cfg.model)
return model.model

def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
return DDP(model)

def _get_params_for_weight_decay_optimization(self, module):

weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
no_weight_decay_params['params'].extend([
p for p in list(module_._parameters.values())
if p is not None
])
else:
weight_decay_params['params'].extend([
p for n, p in list(module_._parameters.items())
if p is not None and 'mask_score' not in n
and 'mask' not in n and n != 'bias'
])
no_weight_decay_params['params'].extend([
p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'
])

return weight_decay_params, no_weight_decay_params

def create_optimizer_and_scheduler(self):
optimizer, lr_scheduler = self.optimizers
optimizer_cfg = self.cfg.train.get('optimizer', None)
optim_options = {}
if optimizer_cfg is not None:
optim_options = optimizer_cfg.pop('options', {})
from deepspeed.ops.adam import DeepSpeedCPUAdam
model = self.model

embeddings = model.module.module.model.bert.embeddings
layers = model.module.module.model.bert.encoder.layer
dec_layers = model.module.module.model.decoder.decoder
param_groups = []
param_groups += list(
self._get_params_for_weight_decay_optimization(layers))
param_groups += list(
self._get_params_for_weight_decay_optimization(embeddings))
param_groups += list(
self._get_params_for_weight_decay_optimization(dec_layers))

for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
optimizer = DeepSpeedCPUAdam(
param_groups,
lr=optimizer_cfg.lr,
weight_decay=optimizer_cfg.weight_decay)

lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
lr_options = {}

if lr_scheduler_cfg is not None:
assert optimizer is not None
lr_options = lr_scheduler_cfg.pop('options', {})
from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR
num_iters = self.max_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=optimizer_cfg.lr,
warmup_iter=lr_scheduler_cfg.warmup * num_iters,
num_iters=num_iters,
decay_style=lr_scheduler_cfg.decay_style,
last_iter=-1)

self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
return self.optimizer, self.lr_scheduler, optim_options, lr_options

def _get_masks_and_position_ids(self, data, eod_token):
# Extract batch size and sequence length.
batch_size, seq_length = data.size()

# Attention mask (lower triangular).
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length),
device=data.device)).view(att_mask_batch, 1, seq_length,
seq_length)

# Loss mask.
loss_mask = torch.ones(
data.size(), dtype=torch.float, device=data.device)
loss_mask[data == eod_token] = 0.0

# Position ids.
position_ids = torch.arange(
seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
return attention_mask, loss_mask, position_ids

def train_step(self, model, inputs):
self._mode = ModeKeys.TRAIN
# format inputs
checkpoint_activations = getattr(self.cfg.train,
'checkpoint_activations', True)
tgt_tokens = inputs['labels'][:, :-1].contiguous()
tgt_labels = inputs['labels'][:, 1:].contiguous()
tgt_attention_mask, dec_loss_mask, position_ids = self._get_masks_and_position_ids(
tgt_tokens, 0)
if getattr(self.cfg.train, 'fp16', None):
tgt_attention_mask = tgt_attention_mask.half()

# forward step
_, output = model(
inputs['input_ids'],
None,
inputs['attention_mask'],
tgt_tokens,
position_ids,
tgt_attention_mask,
checkpoint_activations=checkpoint_activations)

losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
tgt_labels)
dec_loss_mask = dec_loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * dec_loss_mask) / dec_loss_mask.sum()

# add model output info to log
self.train_outputs = {'loss': loss}
self.log_buffer.update(self.train_outputs)

def evaluation_step(self, data):
# wrapper 1: DeepSpeedEngine, wrapper 2: DDP
model = self.model.module.module
model.eval()

# model: fp16 wrapper; model.module: DistributedPlug
vocab_size = model.module.config.original_vocab_size
batch_size = data['input_ids'].shape[0]
beam_generator = TextGenerator(model,
self.eval_preprocessor.nlp_tokenizer,
None)

with torch.no_grad():
tokens = data['input_ids'].long()
padding_mask = data['attention_mask'].byte()
target_ids = data['labels'].long()
target_labels = target_ids[:, 1:].contiguous()
encoder_inputs = [tokens, None, padding_mask]
result = beam_generator.translate_batch(encoder_inputs)
pred_list = result['predictions']
target_list = target_labels.cpu().numpy().tolist()
result['preds'] = []
data['tgts'] = []
for i in range(batch_size):
pred_ids = pred_list[i][0]
pred_ids[pred_ids > vocab_size - 1] = 100
pred_ids = pred_ids.cpu().numpy().tolist()

gold_string = self.eval_preprocessor.decode(
target_list[i], skip_special_tokens=True)
pred_string = self.eval_preprocessor.decode(
pred_ids, skip_special_tokens=True)
result['preds'].append(pred_string)
data['tgts'].append(gold_string)
return result

modelscope/trainers/trainer.py (+5, -2)

@@ -845,7 +845,10 @@ class EpochBasedTrainer(BaseTrainer):
batch_size = batch_size_per_gpu
num_workers = workers_per_gpu

if dist and not isinstance(dataset, torch.utils.data.IterableDataset):
if dist and not isinstance(
dataset,
torch.utils.data.IterableDataset) and self.cfg.model.get(
'model_parallel_size', 1) == 1:
sampler = DistributedSampler(
dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
else:
@@ -935,7 +938,7 @@ class EpochBasedTrainer(BaseTrainer):
""" Evaluation loop used by `EpochBasedTrainer.evaluate()`.

"""
if self._dist:
if self._dist and self.cfg.model.get('model_parallel_size', 1) == 1:
from modelscope.trainers.utils.inference import multi_gpu_test
metric_values = multi_gpu_test(
self,


tests/trainers/test_plug_finetune_text_generation.py (+53, -0)

@@ -0,0 +1,53 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


def test_trainer_with_model_and_args():

def concat_answer_context(dataset):
dataset['src_txt'] = dataset['answers']['text'][0] + '[SEP]' + dataset[
'context']
return dataset

from datasets import load_dataset
dataset_dict = load_dataset('luozhouyang/dureader', 'robust')

train_dataset = dataset_dict['train'].map(concat_answer_context) \
.rename_columns({'question': 'tgt_txt'}).remove_columns('context') \
.remove_columns('id').remove_columns('answers')
eval_dataset = dataset_dict['validation'].map(concat_answer_context) \
.rename_columns({'question': 'tgt_txt'}).remove_columns('context') \
.remove_columns('id').remove_columns('answers')

tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)

model_id = 'damo/nlp_plug_text-generation_27B'

kwargs = dict(
model=model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
work_dir=tmp_dir)

trainer = build_trainer(
name=Trainers.nlp_plug_trainer, default_args=kwargs)
trainer.train()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank')
test_trainer_with_model_and_args()
