
test finetune

master
行嗔 3 years ago
parent
commit
a3aee4bec2
23 changed files with 1661 additions and 54 deletions
  1. modelscope/metainfo.py (+1, -0)
  2. modelscope/models/multi_modal/ofa/generate/sequence_generator.py (+21, -4)
  3. modelscope/models/multi_modal/ofa/modeling_ofa.py (+8, -2)
  4. modelscope/preprocessors/multi_modal.py (+9, -4)
  5. modelscope/preprocessors/ofa/base.py (+2, -1)
  6. modelscope/preprocessors/ofa/image_captioning.py (+8, -3)
  7. modelscope/preprocessors/ofa/image_classification.py (+9, -3)
  8. modelscope/preprocessors/ofa/summarization.py (+9, -3)
  9. modelscope/preprocessors/ofa/text_classification.py (+9, -3)
  10. modelscope/preprocessors/ofa/text_to_image_synthesis.py (+9, -3)
  11. modelscope/preprocessors/ofa/visual_entailment.py (+9, -3)
  12. modelscope/preprocessors/ofa/visual_grounding.py (+9, -3)
  13. modelscope/preprocessors/ofa/visual_question_answering.py (+9, -3)
  14. modelscope/trainers/multi_modal/ofa/__init__.py (+1, -0)
  15. modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py (+2, -0)
  16. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+120, -0)
  17. modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+283, -19)
  18. modelscope/utils/multi_modal/fp16/__init__.py (+14, -0)
  19. modelscope/utils/multi_modal/fp16/fp16.py (+655, -0)
  20. modelscope/utils/multi_modal/fp16/fp16util.py (+216, -0)
  21. modelscope/utils/multi_modal/fp16/loss_scaler.py (+237, -0)
  22. tests/pipelines/test_ofa_tasks.py (+1, -0)
  23. tests/trainers/test_ofa_trainer.py (+20, -0)

modelscope/metainfo.py (+1, -0)

@@ -164,6 +164,7 @@ class Trainers(object):

# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa_tasks = 'ofa-tasks-trainer'

# cv trainers
image_instance_segmentation = 'image-instance-segmentation'

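The new registry key is what `build_trainer` resolves when the OFA trainer is requested by name. A minimal usage sketch, assuming the standard `build_trainer` entry point and a placeholder model id (neither is part of this diff):

# Sketch only: `build_trainer` is ModelScope's registry entry point; the
# model id below is a placeholder, not something this commit adds.
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

trainer = build_trainer(
    name=Trainers.ofa_tasks,  # resolves to 'ofa-tasks-trainer'
    default_args=dict(model='damo/ofa_image-caption_coco_large_en'))
trainer.train()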

modelscope/models/multi_modal/ofa/generate/sequence_generator.py (+21, -4)

@@ -398,10 +398,27 @@ class SequenceGenerator(nn.Module):
if self.should_set_src_lengths:
self.search.set_src_lengths(src_lengths)

if self.repeat_ngram_blocker is not None and step > prefix_tokens.size(
1):
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz,
beam_size, step)
if self.repeat_ngram_blocker is not None:
# process prefix_tokens
p_toks_len = prefix_tokens.ne(self.pad).sum(
dim=1) if prefix_tokens is not None else None
if p_toks_len is not None:
p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat(
1, beam_size).view(-1)
no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size
out_prefix = p_toks_len_beam < (
step + no_repeat_ngram_size - 1)
else:
out_prefix = torch.ones(
bsz * beam_size, dtype=torch.bool, device=tokens.device)
ngram_blocker_tokens = tokens[out_prefix]
ngram_blocker_lprobs = lprobs[out_prefix]
ngram_blocker_bsz = out_prefix.sum() // beam_size
lprobs[out_prefix] = self.repeat_ngram_blocker(
tokens=ngram_blocker_tokens,
lprobs=ngram_blocker_lprobs,
bsz=ngram_blocker_bsz,
beam_size=beam_size,
step=step)

# Shape: (batch, cand_size)
cand_scores, cand_indices, cand_beams = self.search.step(

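The rewritten block no longer skips ngram blocking wholesale while prefix tokens are being forced; it builds a per-beam boolean mask (`out_prefix`) selecting only the beams whose current step has moved past their unpadded prefix, and applies the blocker to those rows alone. A small self-contained sketch of that mask computation, with made-up sizes and pad id:

# Toy reproduction of the prefix-aware mask; sizes and the pad id are made up.
import torch

bsz, beam_size, step = 2, 3, 1
no_repeat_ngram_size, pad = 2, 1
prefix_tokens = torch.tensor([[7, 8, 1],    # prefix length 2 (one pad)
                              [7, 1, 1]])   # prefix length 1 (two pads)

p_toks_len = prefix_tokens.ne(pad).sum(dim=1)                        # tensor([2, 1])
p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat(1, beam_size).view(-1)
out_prefix = p_toks_len_beam < (step + no_repeat_ngram_size - 1)
print(out_prefix)  # False for beams still inside their prefix, True where the blocker applies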

modelscope/models/multi_modal/ofa/modeling_ofa.py (+8, -2)

@@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from packaging import version
from torch import Tensor, nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
@@ -40,6 +41,8 @@ logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = 'ofa-base'
_CONFIG_FOR_DOC = 'OFAConfig'
_TOKENIZER_FOR_DOC = 'OFATokenizer'
TORCH_VERSION = version.parse(torch.__version__)
TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1')

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -114,8 +117,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance):
"""
coords_h = torch.arange(bucket_size)
coords_w = torch.arange(bucket_size)
coords = torch.stack(torch.meshgrid([coords_h, coords_w],
indexing='ij')) # 2, Wh, Ww
if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION:
coords = torch.stack(
torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
else:
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - \
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww

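The version guard exists because the `indexing` keyword was only added to `torch.meshgrid` in newer PyTorch releases; on older installs the call must omit it. Restated as a standalone snippet (the comparison against 1.9.1 is the same one the diff introduces):

# Standalone restatement of the guard above; nothing here is new API.
import torch
from packaging import version

coords_h, coords_w = torch.arange(4), torch.arange(4)
if version.parse(torch.__version__) > version.parse('1.9.1'):
    coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
else:
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))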

modelscope/preprocessors/multi_modal.py (+9, -4)

@@ -11,7 +11,7 @@ from modelscope.metainfo import Preprocessors
from modelscope.pipelines.base import Input
from modelscope.preprocessors.image import load_image
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
from .base import Preprocessor
from .builder import PREPROCESSORS
from .ofa import * # noqa
@@ -27,11 +27,16 @@ __all__ = [
Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor)
class OfaPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
def __init__(self,
model_dir: str,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
model_dir (str): model path
mode: preprocessor mode (model mode)
"""
super().__init__(*args, **kwargs)
preprocess_mapping = {
@@ -59,8 +64,8 @@ class OfaPreprocessor(Preprocessor):
model_dir)
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.preprocess = preprocess_mapping[self.cfg.task](self.cfg,
model_dir)
self.preprocess = preprocess_mapping[self.cfg.task](
cfg=self.cfg, model_dir=model_dir, mode=mode)
self.keys = input_key_mapping[self.cfg.task]
self.tokenizer = self.preprocess.tokenizer

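With `mode` threaded through `OfaPreprocessor` into the task preprocessors, the same class can now be built for training or inference. A minimal sketch, assuming a locally downloaded model directory (the path is a placeholder):

# Sketch only: the model directory path is a placeholder.
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.utils.constant import ModeKeys

train_pre = OfaPreprocessor(model_dir='/path/to/ofa_model', mode=ModeKeys.TRAIN)
infer_pre = OfaPreprocessor(model_dir='/path/to/ofa_model')  # defaults to ModeKeys.INFERENCE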


modelscope/preprocessors/ofa/base.py (+2, -1)

@@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed

class OfaBasePreprocessor:

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self, cfg, model_dir, mode, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
@@ -21,6 +21,7 @@ class OfaBasePreprocessor:
model_dir (str): model path
"""
self.cfg = cfg
self.mode = mode
self.language = self.cfg.model.get('language', 'en')
if self.language == 'en':
tokenizer = OFATokenizer.from_pretrained(model_dir)


modelscope/preprocessors/ofa/image_captioning.py (+8, -3)

@@ -12,16 +12,21 @@ from .base import OfaBasePreprocessor

class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaImageCaptioningPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),


modelscope/preprocessors/ofa/image_classification.py (+9, -3)

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaImageClassificationPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaImageClassificationPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),


modelscope/preprocessors/ofa/summarization.py (+9, -3)

@@ -1,21 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaSummarizationPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaSummarizationPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
source = super().pre_caption(


modelscope/preprocessors/ofa/text_classification.py (+9, -3)

@@ -1,21 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaTextClassificationPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaTextClassificationPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
text1 = ' '.join(


modelscope/preprocessors/ofa/text_to_image_synthesis.py (+9, -3)

@@ -3,21 +3,27 @@ from typing import Any, Dict

import torch

from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaTextToImageSynthesisPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)
self.max_src_length = 64

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:


modelscope/preprocessors/ofa/visual_entailment.py (+9, -3)

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaVisualEntailmentPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),


modelscope/preprocessors/ofa/visual_grounding.py (+9, -3)

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaVisualGroundingPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),


modelscope/preprocessors/ofa/visual_question_answering.py (+9, -3)

@@ -6,21 +6,27 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir, split, *args, **kwargs):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path,
split: data phase
mode: preprocessor mode (model mode)
"""
super(OfaVisualQuestionAnsweringPreprocessor,
self).__init__(cfg, model_dir, split, *args, **kwargs)
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),


modelscope/trainers/multi_modal/ofa/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .ofa_trainer import OFATrainer

modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py (+2, -0)

@@ -78,6 +78,8 @@ class OFAFileDataset:
self.lineid_to_offset.append(offset)
self.total_row_count += 1
offset += len(line.encode('utf-8'))
pickle.dump(self.lineid_to_offset,
open('{}.index'.format(self.file_path), 'wb'))
self._compute_start_pos_and_row_count()
print(
'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'

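The two added lines persist the line-id-to-byte-offset table, so later epochs (or other workers) can seek straight to a row instead of rescanning the file. A minimal sketch of that pattern, using illustrative names that are not OFAFileDataset's actual attributes:

# Sketch of the byte-offset index pattern; names are illustrative.
import pickle

def build_offset_index(file_path):
    offsets, pos = [], 0
    with open(file_path, 'rb') as f:
        for line in f:
            offsets.append(pos)
            pos += len(line)
    with open(file_path + '.index', 'wb') as f:  # note: binary *write* mode
        pickle.dump(offsets, f)
    return offsets

def read_row(file_path, offsets, lineid):
    with open(file_path, 'rb') as f:
        f.seek(offsets[lineid])
        return f.readline().decode('utf-8').rstrip('\n')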

modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+120, -0)

@@ -0,0 +1,120 @@
import os
from os import path as osp
from typing import Dict, Optional

import torch
import torch.distributed as dist
import transformers
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.preprocessors.ofa.utils.collate import collate_fn
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import ModeKeys, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import init_dist
from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
OFADataset, get_schedule)

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
class OFATrainer(BaseTrainer):

def __init__(self, model: str, *args, **kwargs):
model = Model.from_pretrained(model)
super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION))
self.model_dir = model.model_dir
self.model = model.model
self.device_id = 0
self.total_epoch = self.cfg.train.epoch
self.train_batch_size = self.cfg.train.batch_size
self.val_batch_size = self.cfg.evaluation.batch_size
self.save_dir = self.cfg.train.save_dir
init_dist(launcher='pytorch')
self.train_dataset = OFADataset(
file_path=self.cfg.dataset.train_set,
selected_id_keys=self.cfg.dataset.selected_id_keys,
preprocessor=OfaPreprocessor(
model_dir=self.model_dir, mode=ModeKeys.TRAIN),
)
self.val_dataset = OFADataset(
file_path=self.cfg.dataset.valid_set,
selected_id_keys=self.cfg.dataset.selected_id_keys,
preprocessor=OfaPreprocessor(
model_dir=self.model_dir, mode=ModeKeys.EVAL),
)
epoch_steps = len(
self.train_dataset) // self.cfg.train.gradient_accumulation_steps
self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch
self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
self.cfg.train.criterion)

def train(self, *args, **kwargs):
assert dist.is_initialized()

self.model.train()
self.model.to(self.device_id)
ddp_model = torch.nn.parallel.DistributedDataParallel(
self.model, device_ids=[
self.device_id,
])

optimizer = transformers.AdamW(
self.model.parameters(),
lr=self.cfg.train.lr,
weight_decay=self.cfg.train.weight_decay,
correct_bias=False,
)
scheduler_class, scheduler_args = get_schedule(self.cfg.train)
if scheduler_class is not None:
lr_scheduler = scheduler_class(**{'optimizer': optimizer},
**scheduler_args)
else:
lr_scheduler = None
for epoch in range(self.total_epoch):
train_sampler = DistributedSampler(
dataset=self.train_dataset, shuffle=True)
train_sampler.set_epoch(epoch)

train_params = {
'pin_memory': True,
'collate_fn': collate_fn,
'batch_size': self.train_batch_size,
'shuffle': False,
'drop_last': True,
'sampler': train_sampler,
'num_workers': 2,
}

train_loader = DataLoader(self.train_dataset, **train_params)

for idx, batch in enumerate(train_loader, start=1):
model_outputs = ddp_model(**batch)
loss, sample_size, logging_output = self.criterion(
model_outputs, batch)
loss.backward()
optimizer.step()
if lr_scheduler is not None:
lr_scheduler.step()
optimizer.zero_grad()
if idx % 10 == 0:
logger.info(
'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
epoch, idx, len(train_loader), loss.item()))
if dist.get_rank() == 0:
os.makedirs(self.save_dir, exist_ok=True)
torch.save(ddp_model.module.state_dict(),
f'{self.save_dir}/epoch{epoch}.bin')

def evaluate(self,
checkpoint_path: Optional[str] = None,
*args,
**kwargs) -> Dict[str, float]:
pass

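A hedged sketch of driving the new trainer end to end; the model id is a placeholder, the model's configuration.json must carry the `train`/`dataset`/`evaluation` sections read in `__init__`, and because `init_dist(launcher='pytorch')` is called there, the script has to be run through torch's distributed launcher:

# Sketch only: the model id and launch command are assumptions, not part of this diff.
# Launch with e.g.:  python -m torch.distributed.launch --nproc_per_node=1 train_ofa.py
from modelscope.trainers.multi_modal.ofa import OFATrainer

trainer = OFATrainer('damo/ofa_image-caption_coco_large_en')
trainer.train()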
modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+283, -19)

@@ -2,36 +2,36 @@
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from os import path as osp
import math

import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch.nn.modules.loss import _Loss
from torch.utils.data import Dataset

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
from .ofa_file_dataset import OFAFileDataset


class OFADataset(Dataset):

def __init__(self,
model_dir,
file_path,
file_path: str,
preprocessor: OfaPreprocessor,
selected_id_keys: str,
dtypes=None,
separator='\t',
cached_index=False,
split=ModeKeys.TRAIN,
**kwargs):
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
selected_col_ids = self.cfg.dataset.selected_col_ids
selected_col_keys = self.cfg.dataset.selected_col_keys

assert selected_col_ids is not None
assert selected_col_keys is not None
self.selected_col_key_l = selected_col_keys.split(',')
assert len(self.selected_col_key_l) == len(selected_col_ids.split(','))
assert selected_id_keys is not None
selected_col_ids = list()
selected_col_keys = list()
for id_key in selected_id_keys.split(','):
id, key = id_key.split(':')
selected_col_ids.append(id)
selected_col_keys.append(key)

self.dataset = OFAFileDataset(
file_path=file_path,
@@ -39,14 +39,278 @@ class OFADataset(Dataset):
dtypes=dtypes,
separator=separator,
cached_index=cached_index)
self.preprocessor = OfaPreprocessor(model_dir, split)
self.preprocessor = preprocessor

def __len__(self):
return len(self.dataset)

def __getitem__(self, index):
value_l = self.dataset[index]
values = self.dataset[index]
data = dict()
for key, value in zip(self.selected_col_key_l, value_l):
for key, value in zip(self.selected_col_keys, values):
data[key] = value
return self.preprocessor(data)


def construct_rdrop_sample(x):
if isinstance(x, dict):
for key in x:
x[key] = construct_rdrop_sample(x[key])
return x
elif isinstance(x, torch.Tensor):
return x.repeat(2, *([1] * (x.dim() - 1)))
elif isinstance(x, int):
return x * 2
elif isinstance(x, np.ndarray):
return x.repeat(2)
else:
raise NotImplementedError


def kl_loss(p, q):
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
loss = (p_loss + q_loss) / 2
return loss


def label_smoothed_nll_loss(lprobs,
target,
epsilon,
update_num,
reduce=True,
drop_worst_ratio=0.0,
drop_worst_after=0,
use_rdrop=False,
reg_alpha=1.0,
constraint_masks=None,
constraint_start=None,
constraint_end=None):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
if constraint_masks is not None:
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(
dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
elif constraint_start is not None and constraint_end is not None:
constraint_range = [0, 1, 2, 3] + list(
range(constraint_start, constraint_end))
smooth_loss = -lprobs[:, constraint_range].sum(
dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
else:
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
if drop_worst_ratio > 0 and update_num > drop_worst_after:
if use_rdrop:
true_batch_size = loss.size(0) // 2
_, indices = torch.topk(
loss[:true_batch_size],
k=int(true_batch_size * (1 - drop_worst_ratio)),
largest=False)
loss = torch.cat([loss[indices], loss[indices + true_batch_size]])
nll_loss = torch.cat(
[nll_loss[indices], nll_loss[indices + true_batch_size]])
lprobs = torch.cat(
[lprobs[indices], lprobs[indices + true_batch_size]])
else:
loss, indices = torch.topk(
loss,
k=int(loss.shape[0] * (1 - drop_worst_ratio)),
largest=False)
nll_loss = nll_loss[indices]
lprobs = lprobs[indices]

ntokens = loss.numel()
nll_loss = nll_loss.sum()
loss = loss.sum()
if use_rdrop:
true_batch_size = lprobs.size(0) // 2
p = lprobs[:true_batch_size]
q = lprobs[true_batch_size:]
if constraint_start is not None and constraint_end is not None:
constraint_range = [0, 1, 2, 3] + list(
range(constraint_start, constraint_end))
p = p[:, constraint_range]
q = q[:, constraint_range]
loss += kl_loss(p, q) * reg_alpha

return loss, nll_loss, ntokens


class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):

def __init__(self, args):
super().__init__()
self.sentence_avg = args.sentence_avg
self.eps = args.label_smoothing
self.ignore_prefix_size = args.ignore_prefix_size
self.ignore_eos = args.ignore_eos
self.report_accuracy = args.report_accuracy
self.drop_worst_ratio = args.drop_worst_ratio
self.drop_worst_after = args.drop_worst_after
self.use_rdrop = args.use_rdrop
self.reg_alpha = args.reg_alpha
self.sample_patch_num = args.sample_patch_num

self.constraint_start = None
self.constraint_end = None
if args.constraint_range is not None:
constraint_start, constraint_end = args.constraint_range.split(',')
self.constraint_start = int(constraint_start)
self.constraint_end = int(constraint_end)
self.padding_idx = args.tokenizer.pad_token_id
self.args = args

def forward(self, output, sample, update_num=0, reduce=True):
"""Compute the loss for the given sample.

Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
if isinstance(sample, list):
if self.sample_patch_num > 0:
sample[0]['net_input'][
'sample_patch_num'] = self.sample_patch_num
loss_v1, sample_size_v1, logging_output_v1 = self.forward(
output[0], sample[0], update_num, reduce)
loss_v2, sample_size_v2, logging_output_v2 = self.forward(
output[1], sample[1], update_num, reduce)
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
sample_size = 1
logging_output = {
'loss':
loss.data,
'loss_v1':
loss_v1.data,
'loss_v2':
loss_v2.data,
'nll_loss':
logging_output_v1['nll_loss'].data / sample_size_v1
+ logging_output_v2['nll_loss'].data / sample_size_v2,
'ntokens':
logging_output_v1['ntokens'] + logging_output_v2['ntokens'],
'nsentences':
logging_output_v1['nsentences']
+ logging_output_v2['nsentences'],
'sample_size':
1,
'sample_size_v1':
sample_size_v1,
'sample_size_v2':
sample_size_v2,
}
return loss, sample_size, logging_output

if self.use_rdrop:
construct_rdrop_sample(sample)

net_output = output
# model(**sample["net_input"])
loss, nll_loss, ntokens = self.compute_loss(
net_output, sample, update_num, reduce=reduce)
sample_size = (
sample['target'].size(0) if self.sentence_avg else ntokens)
logging_output = {
'loss': loss.data,
'nll_loss': nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output

def get_lprobs_and_target(self, net_output, sample):
conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
'conf'] is not None else 1
constraint_masks = None
if 'constraint_masks' in sample and sample[
'constraint_masks'] is not None:
constraint_masks = sample['constraint_masks']
net_output[0].masked_fill_(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
lprobs = F.log_softmax(
net_output[0], dim=-1, dtype=torch.float32) * conf
target = sample['target']
if self.ignore_prefix_size > 0:
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
target = target[:, self.ignore_prefix_size:].contiguous()
if constraint_masks is not None:
constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable
if self.ignore_eos:
bsz, seq_len, embed_dim = lprobs.size()
eos_indices = target.eq(self.args.tokenizer.eos_token_id)
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim)
target = target[~eos_indices].reshape(bsz, seq_len - 1)
if constraint_masks is not None:
constraint_masks = constraint_masks[~eos_indices].reshape(
bsz, seq_len - 1, embed_dim)
if constraint_masks is not None:
constraint_masks = constraint_masks.view(-1,
constraint_masks.size(-1))
return lprobs.view(-1,
lprobs.size(-1)), target.view(-1), constraint_masks

def compute_loss(self, net_output, sample, update_num, reduce=True):
lprobs, target, constraint_masks = self.get_lprobs_and_target(
net_output, sample)
if constraint_masks is not None:
constraint_masks = constraint_masks[target != self.padding_idx]
lprobs = lprobs[target != self.padding_idx]
target = target[target != self.padding_idx]
loss, nll_loss, ntokens = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
update_num,
reduce=reduce,
drop_worst_ratio=self.drop_worst_ratio,
drop_worst_after=self.drop_worst_after,
use_rdrop=self.use_rdrop,
reg_alpha=self.reg_alpha,
constraint_masks=constraint_masks,
constraint_start=self.constraint_start,
constraint_end=self.constraint_end)
return loss, nll_loss, ntokens


def get_schedule(args):

if args.schedule == 'const':
scheduler_class = transformers.get_constant_schedule_with_warmup
scheduler_args = {
'num_warmup_steps':
int(args.warmup_proportion * args.num_train_steps)
}
elif args.schedule == 'linear':
scheduler_class = transformers.get_linear_schedule_with_warmup
scheduler_args = {
'num_warmup_steps':
int(args.warmup_proportion * args.num_train_steps),
'num_training_steps': args.num_train_steps
}
elif args.schedule == 'cosine':
scheduler_class = transformers.get_cosine_schedule_with_warmup
scheduler_args = {
'num_warmup_steps':
int(args.warmup_proportion * args.num_train_steps),
'num_training_steps': args.num_train_steps
}
elif args.schedule == 'polynomial_decay':
scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup
scheduler_args = {
'num_warmup_steps':
int(args.warmup_proportion * args.num_train_steps),
'num_training_steps': args.num_train_steps,
'lr_end': args.lr_end
}
else:
raise NotImplementedError

return scheduler_class, scheduler_args

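`OFADataset` now takes the column mapping as a single `selected_id_keys` string of comma-separated `column_id:key` pairs instead of reading two parallel config fields. An illustrative example of the format and of how `__getitem__` pairs each selected key with the corresponding TSV column (all values below are made up):

# Illustrative only: the mapping string and row values are made up.
selected_id_keys = '0:image,1:text,2:label'

selected_col_ids, selected_col_keys = [], []
for id_key in selected_id_keys.split(','):
    col_id, key = id_key.split(':')
    selected_col_ids.append(col_id)
    selected_col_keys.append(key)

row = ['oss://bucket/img_001.jpg', 'a dog on the grass', 'positive']
data = dict(zip(selected_col_keys, row))
# {'image': 'oss://bucket/img_001.jpg', 'text': 'a dog on the grass', 'label': 'positive'}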
modelscope/utils/multi_modal/fp16/__init__.py (+14, -0)

@@ -0,0 +1,14 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fp16 import FP16_Module, FP16_Optimizer

modelscope/utils/multi_modal/fp16/fp16.py (+655, -0)

@@ -0,0 +1,655 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stable version of apex FP16 Optimizer"""
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from .fp16util import (master_params_to_model_params,
model_grads_to_master_grads)
from .loss_scaler import DynamicLossScaler, LossScaler

FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn


def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""

def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, FLOAT_TYPES):
val = val.half()
return val

return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""

def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, HALF_TYPES):
val = val.float()
return val

return conversion_helper(val, float_conversion)


class FP16_Module(nn.Module):

def __init__(self, module):
super(FP16_Module, self).__init__()
self.add_module('module', module.half())

def forward(self, *inputs, **kwargs):
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)

def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)


class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
and changing the call to ``backward``.

Example::

model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Name the FP16_Optimizer instance to replace the existing optimizer
# (recommended but not required):
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...

Example with dynamic loss scaling::

...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# optional arg to control dynamic loss scaling behavior
# dynamic_loss_args={'scale_window' : 500})
# Usually, dynamic_loss_args is not necessary.

Args:
init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa

``init_optimizer`` is expected to have been constructed in the ordinary way.
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
named to replace ``init_optimizer``, for two reasons:
First, it means that references to the same name
later in the file will not have to change.
Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
modify ``init_optimizer``. If you do choose a unique name for the new
:class:`FP16_Optimizer` instance, you should only work with this new instance,
because the preexisting optimizer might no longer behave as expected.

``init_optimizer`` may be any Pytorch optimizer.
It may contain a mixture of fp16 and fp32 parameters organized into any number of
``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
ingest these ``param_groups`` and remember them.

Calls to ::

loss.backward()

must be replaced with ::

optimizer.backward(loss)

because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
loss scaling and copies to master gradients.

.. note::
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
are downscaled before being applied. This means that adjusting the loss scale, or using
dynamic loss scaling, should not require retuning the learning rate or any other
hyperparameters.


**Advanced options**

**Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
See docstring for :attr:`step`.

**Gradient clipping**: Use :attr:`clip_master_grads`.

**Multiple losses**: If your model accumulates gradients from multiple losses,
this can be made more efficient by supplying ``update_master_grads=False``
to :attr:`backward`. See docstring for :attr:`backward`.

**Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::

print(optimizer.loss_scale)
optimizer.loss_scale = new_loss_scale

For static loss scaling, manually adjusting the loss scale over time is a reasonable
thing to do. During later epochs, gradients may become smaller, and a
higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
the loss scale is not recommended.

**Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
should still work as intended.
"""

def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=False):
if not torch.cuda.is_available():
raise SystemError('Cannot use fp16 without CUDA.')

self.verbose = verbose

self.optimizer = init_optimizer
# init_state_dict sets up an alternative way to cast per-param state tensors.
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
# init_state_dict = init_optimizer.state_dict()

self.fp16_groups = []
self.fp32_from_fp16_groups = []
self.fp32_from_fp32_groups = []
for i, param_group in enumerate(self.optimizer.param_groups):
self.maybe_print(
'FP16_Optimizer processing param group {}:'.format(i))
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
self.maybe_print(
'FP16_Optimizer received torch.cuda.HalfTensor with {}'
.format(param.size()))
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
# Copy the model parallel flag.
master_param.model_parallel = param.model_parallel
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
# We still need to recast per-param state tensors, if any, to FP32.
if param in self.optimizer.state:
self.optimizer.state[
master_param] = self.optimizer.state.pop(param)
elif param.type() == 'torch.cuda.FloatTensor':
self.maybe_print(
'FP16_Optimizer received torch.cuda.FloatTensor with {}'
.format(param.size()))
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be either '
'torch.cuda.FloatTensor or torch.cuda.HalfTensor. '
'Received {}'.format(param.type()))

self.fp16_groups.append(fp16_params_this_group)
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)

# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# alternative way to cast per-param state tensors:
# self.optimizer.load_state_dict(init_state_dict)

if dynamic_loss_scale:
self.dynamic_loss_scale = True
if dynamic_loss_args is not None:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
else:
self.loss_scaler = DynamicLossScaler()
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(static_loss_scale)

self.overflow = False
self.first_closure_call_this_step = True

self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_

def maybe_print(self, msg):
if self.verbose:
print(msg)

def __getstate__(self):
raise RuntimeError(
'FP16_Optimizer should be serialized using state_dict().')

def __setstate__(self, state):
raise RuntimeError(
'FP16_Optimizer should be deserialized using load_state_dict().')

def zero_grad(self, set_grads_to_None=False):
"""
Zero fp32 and fp16 parameter grads.
"""
# In principle, only the .grad attributes of the model params need to be zeroed,
# because gradients are copied into the FP32 master params. However, we zero
# all gradients owned by the optimizer, just to be safe:
for group in self.optimizer.param_groups:
for p in group['params']:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()

# Zero fp16 gradients owned by the model:
for fp16_group in self.fp16_groups:
for param in fp16_group:
if set_grads_to_None:
param.grad = None
else:
if param.grad is not None:
param.grad.detach_(
) # as in torch.optim.optimizer.zero_grad()
param.grad.zero_()

def _check_overflow(self):
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
for group in self.fp32_from_fp32_groups:
for param in group:
params.append(param)
self.overflow = self.loss_scaler.has_overflow(params)

def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)

def _master_params_to_model_params(self):
for fp16_group, fp32_from_fp16_group in zip(
self.fp16_groups, self.fp32_from_fp16_groups):
master_params_to_model_params(fp16_group, fp32_from_fp16_group)

def _model_params_to_master_params(self):
for fp16_group, fp32_from_fp16_group in zip(
self.fp16_groups, self.fp32_from_fp16_groups):
master_params_to_model_params(fp32_from_fp16_group, fp16_group)

# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
def _model_grads_to_master_grads(self):
for fp16_group, fp32_from_fp16_group in zip(
self.fp16_groups, self.fp32_from_fp16_groups):
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

def _downscale_master(self):
if self.loss_scale != 1.0:
for group in self.optimizer.param_groups:
for param in group['params']:
if param.grad is not None:
param.grad.data.mul_(1. / self.loss_scale)

def clip_master_grads(self, max_norm, norm_type=2):
"""
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.

Args:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.

Returns:
Total norm of the current fp32 gradients (viewed as a single vector).

.. warning::
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa
"""
if not self.overflow:
fp32_params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
fp32_params.append(param)
return self.clip_grad_norm(fp32_params, max_norm, norm_type)
else:
return -1

def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::

checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict[
'first_closure_call_this_step'] = self.first_closure_call_this_step
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
return state_dict

def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.

Example::

model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
self.first_closure_call_this_step = state_dict[
'first_closure_call_this_step']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_from_fp16_groups,
state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)

def step(self, closure=None): # could add clip option.
"""
If no closure is supplied, :attr:`step` should be called after
``fp16_optimizer_obj.backward(loss)``.
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
another forward pass using their model.

If a closure is supplied, :attr:`step` may be called without a prior call to
:attr:`backward(loss)`.
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
However, the user should take care that any ``loss.backward()`` call within the closure
has been replaced by ``fp16_optimizer_obj.backward(loss)``.

Args:
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa

Example with closure::

# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
# existing pytorch optimizer.
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# loss.backward() becomes:
optimizer.backward(loss)
return loss
optimizer.step(closure)

.. warning::
Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.

.. _`ordinary Pytorch optimizer use`:
http://pytorch.org/docs/master/optim.html#optimizer-step-closure
"""

scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)

if self.overflow:
self.maybe_print(
'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}'
.format(scale, self.loss_scale))
return

if closure is not None:
retval = self._step_with_closure(closure)
else:
retval = self.optimizer.step()

self._master_params_to_model_params()

return retval

def _step_with_closure(self, closure):

def wrapped_closure():
# helpful for debugging
# print("Calling wrapped_closure, first_closure_call_this_step = {}"
# .format(self.first_closure_call_this_step))
if self.first_closure_call_this_step:
# We expect that the fp16 params are initially fresh on entering self.step(),
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
# is called within self.optimizer.step().
self.first_closure_call_this_step = False
else:
# If self.optimizer.step() internally calls wrapped_closure more than once,
# it may update the fp32 params after each call. However, self.optimizer
# doesn't know about the fp16 params at all. If the fp32 params get updated,
# we can't rely on self.optimizer to refresh the fp16 params. We need
# to handle that manually:
self._master_params_to_model_params()
# Our API expects the user to give us ownership of the backward() call by
# replacing all calls to loss.backward() with optimizer.backward(loss).
# This requirement holds whether or not the call to backward() is made within a closure.
# If the user is properly calling optimizer.backward(loss) within "closure,"
# calling closure() here will give the fp32 master params fresh gradients
# for the optimizer to play with, so all wrapped_closure needs to do is call
# closure() and return the loss.
temp_loss = closure()
while (self.overflow):
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
self.maybe_print(
'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, '
'reducing to {}'.format(scale, self.loss_scale))
temp_loss = closure()
return temp_loss

retval = self.optimizer.step(wrapped_closure)

self.first_closure_call_this_step = True

return retval

def backward(self, loss, update_master_grads=True, retain_graph=False):
"""
:attr:`backward` performs the following conceptual steps:

1. fp32_loss = loss.float() (see first Note below)
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa
4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa
5. Finally, master grads are divided by loss_scale.

In this way, after :attr:`backward`, the master params have fresh gradients,
and :attr:`step` may be called.

.. note::
:attr:`backward` internally converts the loss to fp32 before applying the loss scale.
This provides some additional safety against overflow if the user has supplied an
fp16 loss value.
However, for maximum overflow safety, the user should
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
:attr:`backward`.

.. warning::
The gradients found in a model's leaves after the call to
:attr:`backward` should not be regarded as valid in general,
because it's possible
they have been scaled (and in the case of dynamic loss scaling,
the scale factor may change over time).
If the user wants to inspect gradients after a call to :attr:`backward`,
only the master gradients should be regarded as valid. These can be retrieved via
:attr:`inspect_master_grad_data()`.

Args:
loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa
retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa

Example::

# Ordinary operation:
optimizer.backward(loss)

# Naive operation with multiple losses (technically valid, but less efficient):
# fp32 grads will be correct after the second call, but
# the first call incurs an unnecessary fp16->fp32 grad copy.
optimizer.backward(loss1)
optimizer.backward(loss2)

# More efficient way to handle multiple losses:
# The fp16->fp32 grad copy is delayed until fp16 grads from all
# losses have been accumulated.
optimizer.backward(loss1, update_master_grads=False)
optimizer.backward(loss2, update_master_grads=False)
optimizer.update_master_grads()
"""
# To consider: try multiple backward passes using retain_grad=True to find
# a loss scale that works. After you find a loss scale that works, do a final dummy
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid
# discarding the iteration, but probably wouldn't improve overall efficiency.
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
if update_master_grads:
self.update_master_grads()

def update_master_grads(self):
"""
Copy the ``.grad`` attribute from stored references to fp16 parameters to
the ``.grad`` attribute of the fp32 master parameters that are directly
updated by the optimizer. :attr:`update_master_grads` only needs to be called if
``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow: return # noqa
self._model_grads_to_master_grads()
self._downscale_master()

def inspect_master_grad_data(self):
"""
When running with :class:`FP16_Optimizer`,
``.grad`` attributes of a model's fp16 leaves should not be
regarded as truthful, because they might be scaled.
After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
the fp32 master params' ``.grad``
attributes will contain valid gradients properly divided by the loss scale. However,
because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
nonintuitive. :attr:`inspect_master_grad_data`
allows those gradients to be viewed with shapes corresponding to their associated model leaves.

Returns:
List of lists (one list for each parameter group). The list for each parameter group
is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
"""
if self.overflow:
print(
'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. '
'Gradients are currently invalid (may be inf, nan, or stale). Returning None.'
)
return None
else:
# The optimizer owns only references to master params.
master_grads_data = []
for param_group in self.optimizer.param_groups:
master_grads_this_group = []
for param in param_group['params']:
if param.grad is not None:
master_grads_this_group.append(param.grad.data)
else:
master_grads_this_group.append(None)
master_grads_data.append(master_grads_this_group)
return master_grads_data

# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale

def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value

loss_scale = property(_get_loss_scale, _set_loss_scale)

# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state

def _set_state(self, value):
self.optimizer.state = value

state = property(_get_state, _set_state)

# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups

def _set_param_groups(self, value):
self.optimizer.param_groups = value

param_groups = property(_get_param_groups, _set_param_groups)

modelscope/utils/multi_modal/fp16/fp16util.py (+216, -0)

@@ -0,0 +1,216 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable


class tofp16(nn.Module):
"""
Utility module that implements::

def forward(self, input):
return input.half()
"""

def __init__(self):
super(tofp16, self).__init__()

def forward(self, input):
return input.half()


def BN_convert_float(module):
"""
Utility function for network_to_half().

Retained for legacy purposes.
"""
if isinstance(
module,
torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
module.float()
for child in module.children():
BN_convert_float(child)
return module


def network_to_half(network):
"""
Convert model to half precision in a batchnorm-safe way.

Retained for legacy purposes. It is recommended to use FP16Model.
"""
return nn.Sequential(tofp16(), BN_convert_float(network.half()))


def convert_module(module, dtype):
"""
Converts a module's immediate parameters and buffers to dtype.
"""
for param in module.parameters(recurse=False):
if param is not None:
if param.data.dtype.is_floating_point:
param.data = param.data.to(dtype=dtype)
if param._grad is not None and param._grad.data.dtype.is_floating_point:
param._grad.data = param._grad.data.to(dtype=dtype)

for buf in module.buffers(recurse=False):
if buf is not None and buf.data.dtype.is_floating_point:
buf.data = buf.data.to(dtype=dtype)


def convert_network(network, dtype):
"""
Converts a network's parameters and buffers to dtype.
"""
for module in network.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm
) and module.affine is True:
continue
convert_module(module, dtype)
return network


class FP16Model(nn.Module):
"""
Convert model to half precision in a batchnorm-safe way.
"""

def __init__(self, network):
super(FP16Model, self).__init__()
self.network = convert_network(network, dtype=torch.half)

def forward(self, *inputs):
inputs = tuple(t.half() for t in inputs)
return self.network(*inputs)


def backwards_debug_hook(grad):
raise RuntimeError(
        'master_params received a gradient in the backward pass!')


def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
`Training Neural Networks with Mixed Precision: Real Examples`_.

Args:
model (torch.nn.Module): Existing Pytorch model
flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa
Returns:
        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa

Example::

model_params, master_params = prep_param_lists(model)

.. warning::
Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa

.. _`Training Neural Networks with Mixed Precision: Real Examples`:
http://on-demand.gputechconf.com/gtc/2018/video/S81012/
"""
model_params = [
param for param in model.parameters() if param.requires_grad
]

if flat_master:
# Give the user some more useful error messages
try:
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors(
[param.data for param in model_params]).float()
except: # noqa
print(
'Error in prep_param_lists: model may contain a mixture of parameters '
                'of different types. Use flat_master=False, or use FP16_Optimizer.'
)
raise
master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True
# master_params.register_hook(backwards_debug_hook)
if master_params.grad is None:
master_params.grad = master_params.new(*master_params.size())
return model_params, [master_params]
else:
master_params = [
param.clone().float().detach() for param in model_params
]
for param in master_params:
param.requires_grad = True
return model_params, master_params


def model_grads_to_master_grads(model_params,
master_params,
flat_master=False):
"""
Copy model gradients to master gradients.

Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa
"""
if flat_master:
# The flattening may incur one more deep copy than is necessary.
master_params[0].grad.data.copy_(
_flatten_dense_tensors([p.grad.data for p in model_params]))
else:
for model, master in zip(model_params, master_params):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(
master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else:
master.grad = None


def master_params_to_model_params(model_params,
master_params,
flat_master=False):
"""
Copy master parameters to model parameters.

Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa
"""
if flat_master:
for model, master in zip(
model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master)
else:
for model, master in zip(model_params, master_params):
model.data.copy_(master.data)
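

# A minimal sketch tying the helpers above into one manual mixed-precision
# step (static loss scale, flat_master=False). The model, data, and scale
# value are illustrative only, and fp16 matmuls generally expect a GPU,
# hence the cuda guard.
if __name__ == '__main__':
    if torch.cuda.is_available():
        model = network_to_half(nn.Linear(8, 2)).cuda()
        model_params, master_params = prep_param_lists(model)
        optimizer = torch.optim.SGD(master_params, lr=1e-2)  # owns the fp32 masters
        loss_scale = 128.0

        x = torch.randn(4, 8, device='cuda')
        y = torch.randn(4, 2, device='cuda')

        model.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x).float(), y)
        (loss * loss_scale).backward()  # scaled fp16 backward
        model_grads_to_master_grads(model_params, master_params)  # fp16 grads -> fp32 masters
        for p in master_params:
            if p.grad is not None:
                p.grad.data.mul_(1.0 / loss_scale)  # unscale in fp32
        optimizer.step()  # update fp32 masters
        master_params_to_model_params(model_params, master_params)  # copy back into the fp16 model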


# Backward compatibility fixes


def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]


TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

+ 237
- 0
modelscope/utils/multi_modal/fp16/loss_scaler.py View File

@@ -0,0 +1,237 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]


class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.

Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.

Args:
scale (float, optional, default=1.0): The loss scale.
"""

def __init__(self, scale=1):
self.cur_scale = scale

# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        return False

def update_scale(self, overflow):
pass

@property
def loss_scale(self):
return self.cur_scale

def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)

def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)


class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
    operates, because the default options can be changed using the
    ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.

Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.

Args:
        init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`.
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. # noqa
"""

def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis

# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(
p.grad.data):
return True

return False

    def has_overflow(self, params):
        # The original code routed the flag through a torch.cuda.ByteTensor,
        # apparently a leftover from a removed distributed all-reduce; that
        # breaks on CPU-only machines, and the serial check already returns
        # a Python bool.
        return self.has_overflow_serial(params)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent versions of PyTorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
            # We want to check if instance is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if 'value cannot be converted' not in instance.args[0]:
raise
return True
else:
if cpu_sum == float(
'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False

# `overflow` is boolean indicating whether the gradient overflowed
def update_scale(self, overflow):

if not hasattr(self, 'min_scale'):
self.min_scale = 1
if not hasattr(self, 'delayed_shift'):
self.delayed_shift = 1
if not hasattr(self, 'cur_hysteresis'):
self.cur_hysteresis = 1
if not hasattr(self, 'consecutive_hysteresis'):
self.consecutive_hysteresis = True
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale / self.scale_factor,
self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
else:
if self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
if (self.cur_iter
- self.last_overflow_iter) % self.scale_window == 0:
if not self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
self.cur_scale *= self.scale_factor
self.cur_iter += 1

@property
def loss_scale(self):
return self.cur_scale

def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)

def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
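

# A small worked trace of update_scale(), assuming the defaults except for a
# short scale_window so the doubling is visible: the scale halves on each
# overflow and doubles again after `scale_window` consecutive clean steps.
if __name__ == '__main__':
    scaler = DynamicLossScaler(init_scale=2**16, scale_factor=2., scale_window=4)
    for step, overflow in enumerate([True, False, False, False, False, True]):
        scaler.update_scale(overflow)
        print(step, overflow, scaler.loss_scale)
    # prints 32768.0 for steps 0-3, 65536.0 at step 4, 32768.0 again at step 5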


##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)

w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]

learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()

for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))

# Run backprop
optimizer.zero_grad()
loss.backward()

# Check for overflow
        has_overflow = loss_scaler.has_overflow(parameters)

# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')

# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)

"""

+ 1
- 0
tests/pipelines/test_ofa_tasks.py View File

@@ -172,6 +172,7 @@ class OfaTasksTest(unittest.TestCase):
ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
image = 'data/test/images/visual_grounding.png'
text = '一个圆头的蓝色宝可梦'
text = '火'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)


+ 20
- 0
tests/trainers/test_ofa_trainer.py View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from modelscope.trainers.multi_modal.ofa import OFATrainer
from modelscope.utils.test_utils import test_level


class TestOfaTrainer(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
self.trainer = OFATrainer(model_id)
self.trainer.train()
shutil.rmtree(self.trainer.save_dir)


if __name__ == '__main__':
unittest.main()
