@@ -164,6 +164,7 @@ class Trainers(object): | |||||
# multi-modal trainers | # multi-modal trainers | ||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
ofa_tasks = 'ofa-tasks-trainer' | |||||
# cv trainers | # cv trainers | ||||
image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
@@ -398,10 +398,27 @@ class SequenceGenerator(nn.Module): | |||||
if self.should_set_src_lengths: | if self.should_set_src_lengths: | ||||
self.search.set_src_lengths(src_lengths) | self.search.set_src_lengths(src_lengths) | ||||
if self.repeat_ngram_blocker is not None and step > prefix_tokens.size( | |||||
1): | |||||
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, | |||||
beam_size, step) | |||||
if self.repeat_ngram_blocker is not None: | |||||
# process prefix_tokens | |||||
p_toks_len = prefix_tokens.ne(self.pad).sum( | |||||
dim=1) if prefix_tokens is not None else None | |||||
if p_toks_len is not None: | |||||
p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat( | |||||
1, beam_size).view(-1) | |||||
no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size | |||||
out_prefix = p_toks_len_beam < ( | |||||
step + no_repeat_ngram_size - 1) | |||||
else: | |||||
out_prefix = [True] * bsz * beam_size | |||||
ngram_blocker_tokens = tokens[out_prefix] | |||||
ngram_blocker_lprobs = lprobs[out_prefix] | |||||
ngram_blocker_bsz = out_prefix.sum() // beam_size | |||||
lprobs[out_prefix] = self.repeat_ngram_blocker( | |||||
tokens=ngram_blocker_tokens, | |||||
lprobs=ngram_blocker_lprobs, | |||||
bsz=ngram_blocker_bsz, | |||||
beam_size=beam_size, | |||||
step=step) | |||||
# Shape: (batch, cand_size) | # Shape: (batch, cand_size) | ||||
cand_scores, cand_indices, cand_beams = self.search.step( | cand_scores, cand_indices, cand_beams = self.search.step( | ||||
@@ -19,6 +19,7 @@ from dataclasses import dataclass | |||||
from typing import Dict, List, Optional, Tuple | from typing import Dict, List, Optional, Tuple | ||||
import torch | import torch | ||||
from packaging import version | |||||
from torch import Tensor, nn | from torch import Tensor, nn | ||||
from torch.nn import functional as F | from torch.nn import functional as F | ||||
from transformers.activations import ACT2FN | from transformers.activations import ACT2FN | ||||
@@ -40,6 +41,8 @@ logger = logging.get_logger(__name__) | |||||
_CHECKPOINT_FOR_DOC = 'ofa-base' | _CHECKPOINT_FOR_DOC = 'ofa-base' | ||||
_CONFIG_FOR_DOC = 'OFAConfig' | _CONFIG_FOR_DOC = 'OFAConfig' | ||||
_TOKENIZER_FOR_DOC = 'OFATokenizer' | _TOKENIZER_FOR_DOC = 'OFATokenizer' | ||||
TORCH_VERSION = version.parse(torch.__version__) | |||||
TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') | |||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024 | DEFAULT_MAX_SOURCE_POSITIONS = 1024 | ||||
DEFAULT_MAX_TARGET_POSITIONS = 1024 | DEFAULT_MAX_TARGET_POSITIONS = 1024 | ||||
@@ -114,8 +117,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance): | |||||
""" | """ | ||||
coords_h = torch.arange(bucket_size) | coords_h = torch.arange(bucket_size) | ||||
coords_w = torch.arange(bucket_size) | coords_w = torch.arange(bucket_size) | ||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w], | |||||
indexing='ij')) # 2, Wh, Ww | |||||
if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION: | |||||
coords = torch.stack( | |||||
torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww | |||||
else: | |||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) | |||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | ||||
relative_coords = coords_flatten[:, :, None] - \ | relative_coords = coords_flatten[:, :, None] - \ | ||||
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww | ||||
@@ -11,7 +11,7 @@ from modelscope.metainfo import Preprocessors | |||||
from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
from modelscope.preprocessors.image import load_image | from modelscope.preprocessors.image import load_image | ||||
from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
from modelscope.utils.constant import Fields, ModelFile, Tasks | |||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks | |||||
from .base import Preprocessor | from .base import Preprocessor | ||||
from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
from .ofa import * # noqa | from .ofa import * # noqa | ||||
@@ -27,11 +27,16 @@ __all__ = [ | |||||
Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) | Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) | ||||
class OfaPreprocessor(Preprocessor): | class OfaPreprocessor(Preprocessor): | ||||
def __init__(self, model_dir: str, *args, **kwargs): | |||||
def __init__(self, | |||||
model_dir: str, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
model_dir (str): model path | model_dir (str): model path | ||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
preprocess_mapping = { | preprocess_mapping = { | ||||
@@ -59,8 +64,8 @@ class OfaPreprocessor(Preprocessor): | |||||
model_dir) | model_dir) | ||||
self.cfg = Config.from_file( | self.cfg = Config.from_file( | ||||
osp.join(model_dir, ModelFile.CONFIGURATION)) | osp.join(model_dir, ModelFile.CONFIGURATION)) | ||||
self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, | |||||
model_dir) | |||||
self.preprocess = preprocess_mapping[self.cfg.task]( | |||||
cfg=self.cfg, model_dir=model_dir, mode=mode) | |||||
self.keys = input_key_mapping[self.cfg.task] | self.keys = input_key_mapping[self.cfg.task] | ||||
self.tokenizer = self.preprocess.tokenizer | self.tokenizer = self.preprocess.tokenizer | ||||
@@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed | |||||
class OfaBasePreprocessor: | class OfaBasePreprocessor: | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, cfg, model_dir, mode, *args, **kwargs): | |||||
"""preprocess the data via the vocab.txt from the `model_dir` path | """preprocess the data via the vocab.txt from the `model_dir` path | ||||
Args: | Args: | ||||
@@ -21,6 +21,7 @@ class OfaBasePreprocessor: | |||||
model_dir (str): model path | model_dir (str): model path | ||||
""" | """ | ||||
self.cfg = cfg | self.cfg = cfg | ||||
self.mode = mode | |||||
self.language = self.cfg.model.get('language', 'en') | self.language = self.cfg.model.get('language', 'en') | ||||
if self.language == 'en': | if self.language == 'en': | ||||
tokenizer = OFATokenizer.from_pretrained(model_dir) | tokenizer = OFATokenizer.from_pretrained(model_dir) | ||||
@@ -12,16 +12,21 @@ from .base import OfaBasePreprocessor | |||||
class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaImageCaptioningPreprocessor, | super(OfaImageCaptioningPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
# Initialize transform | # Initialize transform | ||||
self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
lambda image: image.convert('RGB'), | lambda image: image.convert('RGB'), | ||||
@@ -6,21 +6,27 @@ from PIL import Image | |||||
from torchvision import transforms | from torchvision import transforms | ||||
from modelscope.preprocessors.image import load_image | from modelscope.preprocessors.image import load_image | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaImageClassificationPreprocessor(OfaBasePreprocessor): | class OfaImageClassificationPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaImageClassificationPreprocessor, | super(OfaImageClassificationPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
# Initialize transform | # Initialize transform | ||||
self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
lambda image: image.convert('RGB'), | lambda image: image.convert('RGB'), | ||||
@@ -1,21 +1,27 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaSummarizationPreprocessor(OfaBasePreprocessor): | class OfaSummarizationPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaSummarizationPreprocessor, | super(OfaSummarizationPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
source = super().pre_caption( | source = super().pre_caption( | ||||
@@ -1,21 +1,27 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaTextClassificationPreprocessor(OfaBasePreprocessor): | class OfaTextClassificationPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaTextClassificationPreprocessor, | super(OfaTextClassificationPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
text1 = ' '.join( | text1 = ' '.join( | ||||
@@ -3,21 +3,27 @@ from typing import Any, Dict | |||||
import torch | import torch | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): | class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaTextToImageSynthesisPreprocessor, | super(OfaTextToImageSynthesisPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
self.max_src_length = 64 | self.max_src_length = 64 | ||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
@@ -6,21 +6,27 @@ from PIL import Image | |||||
from torchvision import transforms | from torchvision import transforms | ||||
from modelscope.preprocessors.image import load_image | from modelscope.preprocessors.image import load_image | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): | class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaVisualEntailmentPreprocessor, | super(OfaVisualEntailmentPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
# Initialize transform | # Initialize transform | ||||
self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
lambda image: image.convert('RGB'), | lambda image: image.convert('RGB'), | ||||
@@ -6,21 +6,27 @@ from PIL import Image | |||||
from torchvision import transforms | from torchvision import transforms | ||||
from modelscope.preprocessors.image import load_image | from modelscope.preprocessors.image import load_image | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): | class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaVisualGroundingPreprocessor, | super(OfaVisualGroundingPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
# Initialize transform | # Initialize transform | ||||
self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
lambda image: image.convert('RGB'), | lambda image: image.convert('RGB'), | ||||
@@ -6,21 +6,27 @@ from PIL import Image | |||||
from torchvision import transforms | from torchvision import transforms | ||||
from modelscope.preprocessors.image import load_image | from modelscope.preprocessors.image import load_image | ||||
from modelscope.utils.constant import ModeKeys | |||||
from .base import OfaBasePreprocessor | from .base import OfaBasePreprocessor | ||||
class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): | class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): | ||||
def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||||
def __init__(self, | |||||
cfg, | |||||
model_dir, | |||||
mode=ModeKeys.INFERENCE, | |||||
*args, | |||||
**kwargs): | |||||
"""preprocess the data | """preprocess the data | ||||
Args: | Args: | ||||
cfg(modelscope.utils.config.ConfigDict) : model config | cfg(modelscope.utils.config.ConfigDict) : model config | ||||
model_dir (str): model path, | model_dir (str): model path, | ||||
split: data phase | |||||
mode: preprocessor mode (model mode) | |||||
""" | """ | ||||
super(OfaVisualQuestionAnsweringPreprocessor, | super(OfaVisualQuestionAnsweringPreprocessor, | ||||
self).__init__(cfg, model_dir, split, *args, **kwargs) | |||||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||||
# Initialize transform | # Initialize transform | ||||
self.patch_resize_transform = transforms.Compose([ | self.patch_resize_transform = transforms.Compose([ | ||||
lambda image: image.convert('RGB'), | lambda image: image.convert('RGB'), | ||||
@@ -0,0 +1 @@ | |||||
from .ofa_trainer import OFATrainer |
@@ -78,6 +78,8 @@ class OFAFileDataset: | |||||
self.lineid_to_offset.append(offset) | self.lineid_to_offset.append(offset) | ||||
self.total_row_count += 1 | self.total_row_count += 1 | ||||
offset += len(line.encode('utf-8')) | offset += len(line.encode('utf-8')) | ||||
pickle.dump(self.lineid_to_offset, | |||||
open('{}.index'.format(self.file_path), 'rb')) | |||||
self._compute_start_pos_and_row_count() | self._compute_start_pos_and_row_count() | ||||
print( | print( | ||||
'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping' | 'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping' | ||||
@@ -0,0 +1,120 @@ | |||||
import os | |||||
from os import path as osp | |||||
from typing import Dict, Optional | |||||
import torch | |||||
import torch.distributed as dist | |||||
import transformers | |||||
from torch.utils.data import DataLoader | |||||
from torch.utils.data.distributed import DistributedSampler | |||||
from modelscope.metainfo import Trainers | |||||
from modelscope.models.base import Model | |||||
from modelscope.preprocessors.multi_modal import OfaPreprocessor | |||||
from modelscope.preprocessors.ofa.utils.collate import collate_fn | |||||
from modelscope.trainers.base import BaseTrainer | |||||
from modelscope.trainers.builder import TRAINERS | |||||
from modelscope.utils.constant import ModeKeys, ModelFile | |||||
from modelscope.utils.logger import get_logger | |||||
from modelscope.utils.torch_utils import init_dist | |||||
from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | |||||
OFADataset, get_schedule) | |||||
logger = get_logger() | |||||
@TRAINERS.register_module(module_name=Trainers.ofa_tasks) | |||||
class OFATrainer(BaseTrainer): | |||||
def __init__(self, model: str, *args, **kwargs): | |||||
model = Model.from_pretrained(model) | |||||
super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION)) | |||||
self.model_dir = model.model_dir | |||||
self.model = model.model | |||||
self.device_id = 0 | |||||
self.total_epoch = self.cfg.train.epoch | |||||
self.train_batch_size = self.cfg.train.batch_size | |||||
self.val_batch_size = self.cfg.evaluation.batch_size | |||||
self.save_dir = self.cfg.train.save_dir | |||||
init_dist(launcher='pytorch') | |||||
self.train_dataset = OFADataset( | |||||
file_path=self.cfg.dataset.train_set, | |||||
selected_id_keys=self.cfg.dataset.selected_id_keys, | |||||
preprocessor=OfaPreprocessor( | |||||
model_dir=self.model_dir, split=ModeKeys.TRAIN), | |||||
) | |||||
self.val_dataset = OFADataset( | |||||
file_path=self.cfg.dataset.valid_set, | |||||
selected_id_keys=self.cfg.dataset.selected_id_keys, | |||||
preprocessor=OfaPreprocessor( | |||||
model_dir=self.model_dir, split=ModeKeys.EVAL), | |||||
) | |||||
epoch_steps = len( | |||||
self.train_dataset) // self.cfg.train.gradient_accumulation_steps | |||||
self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch | |||||
self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( | |||||
self.cfg.train.criterion) | |||||
def train(self, *args, **kwargs): | |||||
assert dist.is_initialized() | |||||
self.model.train() | |||||
self.model.to(self.device_id) | |||||
ddp_model = torch.nn.parallel.DistributedDataParallel( | |||||
self.model, device_ids=[ | |||||
self.device_id, | |||||
]) | |||||
optimizer = transformers.AdamW( | |||||
self.model.parameters(), | |||||
lr=self.cfg.train.lr, | |||||
weight_decay=self.cfg.train.weight_decay, | |||||
correct_bias=False, | |||||
) | |||||
scheduler_class, scheduler_args = get_schedule(self.cfg.train) | |||||
if scheduler_class is not None: | |||||
lr_scheduler = scheduler_class(**{'optimizer': optimizer}, | |||||
**scheduler_args) | |||||
else: | |||||
lr_scheduler = None | |||||
for epoch in range(self.total_epoch): | |||||
train_sampler = DistributedSampler( | |||||
dataset=self.train_dataset, shuffle=True) | |||||
train_sampler.set_epoch(epoch) | |||||
train_params = { | |||||
'pin_memory': True, | |||||
'collate_fn': collate_fn, | |||||
'batch_size': self.train_batch_size, | |||||
'shuffle': False, | |||||
'drop_last': True, | |||||
'sampler': train_sampler, | |||||
'num_workers': 2, | |||||
} | |||||
train_loader = DataLoader(self.train_dataset, **train_params) | |||||
for idx, batch in enumerate(train_loader, start=1): | |||||
model_outputs = ddp_model(**batch) | |||||
loss, sample_size, logging_output = self.criterion( | |||||
model_outputs, batch) | |||||
loss.backward() | |||||
optimizer.zero_grad() | |||||
if lr_scheduler is not None: | |||||
lr_scheduler.step() | |||||
optimizer.step() | |||||
optimizer.zero_grad() | |||||
if idx % 10 == 0: | |||||
logger.info( | |||||
'epoch: {}, train batch {}/{}, loss={:.5f}'.format( | |||||
epoch, idx, len(train_loader), loss.item())) | |||||
if dist.get_rank() == 0: | |||||
os.makedirs(self.ckpt_dir, exist_ok=True) | |||||
torch.save(ddp_model.module.state_dict(), | |||||
f'{self.ckpt_dir}/epoch{epoch}.bin') | |||||
def evaluate(self, | |||||
checkpoint_path: Optional[str] = None, | |||||
*args, | |||||
**kwargs) -> Dict[str, float]: | |||||
pass |
@@ -2,36 +2,36 @@ | |||||
# All rights reserved. | # All rights reserved. | ||||
# This source code is licensed under the Apache 2.0 license | # This source code is licensed under the Apache 2.0 license | ||||
# found in the LICENSE file in the root directory. | # found in the LICENSE file in the root directory. | ||||
from os import path as osp | |||||
import math | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import transformers | |||||
from torch.nn.modules.loss import _Loss | |||||
from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
from modelscope.hub.snapshot_download import snapshot_download | |||||
from modelscope.preprocessors.multi_modal import OfaPreprocessor | from modelscope.preprocessors.multi_modal import OfaPreprocessor | ||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks | |||||
from .ofa_file_dataset import OFAFileDataset | from .ofa_file_dataset import OFAFileDataset | ||||
class OFADataset(Dataset): | class OFADataset(Dataset): | ||||
def __init__(self, | def __init__(self, | ||||
model_dir, | |||||
file_path, | |||||
file_path: str, | |||||
preprocessor: OfaPreprocessor, | |||||
selected_id_keys: str, | |||||
dtypes=None, | dtypes=None, | ||||
separator='\t', | separator='\t', | ||||
cached_index=False, | cached_index=False, | ||||
split=ModeKeys.TRAIN, | |||||
**kwargs): | **kwargs): | ||||
self.cfg = Config.from_file( | |||||
osp.join(model_dir, ModelFile.CONFIGURATION)) | |||||
selected_col_ids = self.cfg.dataset.selected_col_ids | |||||
selected_col_keys = self.cfg.dataset.selected_col_keys | |||||
assert selected_col_ids is not None | |||||
assert selected_col_keys is not None | |||||
self.selected_col_key_l = selected_col_keys.split(',') | |||||
assert len(self.selected_col_key_l) == len(selected_col_ids.split(',')) | |||||
assert selected_id_keys is not None | |||||
selected_col_ids = list() | |||||
selected_col_keys = list() | |||||
for id_key in selected_id_keys.split(','): | |||||
id, key = id_key.split(':') | |||||
selected_col_ids.append(id) | |||||
selected_col_keys.append(key) | |||||
self.dataset = OFAFileDataset( | self.dataset = OFAFileDataset( | ||||
file_path=file_path, | file_path=file_path, | ||||
@@ -39,14 +39,278 @@ class OFADataset(Dataset): | |||||
dtypes=dtypes, | dtypes=dtypes, | ||||
separator=separator, | separator=separator, | ||||
cached_index=cached_index) | cached_index=cached_index) | ||||
self.preprocessor = OfaPreprocessor(model_dir, split) | |||||
self.preprocessor = preprocessor | |||||
def __len__(self): | def __len__(self): | ||||
return len(self.dataset) | return len(self.dataset) | ||||
def __getitem__(self, index): | def __getitem__(self, index): | ||||
value_l = self.dataset[index] | |||||
values = self.dataset[index] | |||||
data = dict() | data = dict() | ||||
for key, value in zip(self.selected_col_key_l, value_l): | |||||
for key, value in zip(self.selected_col_keys, values): | |||||
data[key] = value | data[key] = value | ||||
return self.preprocessor(data) | return self.preprocessor(data) | ||||
def construct_rdrop_sample(x): | |||||
if isinstance(x, dict): | |||||
for key in x: | |||||
x[key] = construct_rdrop_sample(x[key]) | |||||
return x | |||||
elif isinstance(x, torch.Tensor): | |||||
return x.repeat(2, *([1] * (x.dim() - 1))) | |||||
elif isinstance(x, int): | |||||
return x * 2 | |||||
elif isinstance(x, np.ndarray): | |||||
return x.repeat(2) | |||||
else: | |||||
raise NotImplementedError | |||||
def kl_loss(p, q): | |||||
p_loss = F.kl_div(p, torch.exp(q), reduction='sum') | |||||
q_loss = F.kl_div(q, torch.exp(p), reduction='sum') | |||||
loss = (p_loss + q_loss) / 2 | |||||
return loss | |||||
def label_smoothed_nll_loss(lprobs, | |||||
target, | |||||
epsilon, | |||||
update_num, | |||||
reduce=True, | |||||
drop_worst_ratio=0.0, | |||||
drop_worst_after=0, | |||||
use_rdrop=False, | |||||
reg_alpha=1.0, | |||||
constraint_masks=None, | |||||
constraint_start=None, | |||||
constraint_end=None): | |||||
if target.dim() == lprobs.dim() - 1: | |||||
target = target.unsqueeze(-1) | |||||
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1) | |||||
if constraint_masks is not None: | |||||
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum( | |||||
dim=-1, keepdim=True).squeeze(-1) | |||||
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6) | |||||
elif constraint_start is not None and constraint_end is not None: | |||||
constraint_range = [0, 1, 2, 3] + list( | |||||
range(constraint_start, constraint_end)) | |||||
smooth_loss = -lprobs[:, constraint_range].sum( | |||||
dim=-1, keepdim=True).squeeze(-1) | |||||
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6) | |||||
else: | |||||
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1) | |||||
eps_i = epsilon / (lprobs.size(-1) - 1) | |||||
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss | |||||
if drop_worst_ratio > 0 and update_num > drop_worst_after: | |||||
if use_rdrop: | |||||
true_batch_size = loss.size(0) // 2 | |||||
_, indices = torch.topk( | |||||
loss[:true_batch_size], | |||||
k=int(true_batch_size * (1 - drop_worst_ratio)), | |||||
largest=False) | |||||
loss = torch.cat([loss[indices], loss[indices + true_batch_size]]) | |||||
nll_loss = torch.cat( | |||||
[nll_loss[indices], nll_loss[indices + true_batch_size]]) | |||||
lprobs = torch.cat( | |||||
[lprobs[indices], lprobs[indices + true_batch_size]]) | |||||
else: | |||||
loss, indices = torch.topk( | |||||
loss, | |||||
k=int(loss.shape[0] * (1 - drop_worst_ratio)), | |||||
largest=False) | |||||
nll_loss = nll_loss[indices] | |||||
lprobs = lprobs[indices] | |||||
ntokens = loss.numel() | |||||
nll_loss = nll_loss.sum() | |||||
loss = loss.sum() | |||||
if use_rdrop: | |||||
true_batch_size = lprobs.size(0) // 2 | |||||
p = lprobs[:true_batch_size] | |||||
q = lprobs[true_batch_size:] | |||||
if constraint_start is not None and constraint_end is not None: | |||||
constraint_range = [0, 1, 2, 3] + list( | |||||
range(constraint_start, constraint_end)) | |||||
p = p[:, constraint_range] | |||||
q = q[:, constraint_range] | |||||
loss += kl_loss(p, q) * reg_alpha | |||||
return loss, nll_loss, ntokens | |||||
class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||||
def __init__(self, args): | |||||
super().__init__() | |||||
self.sentence_avg = args.sentence_avg | |||||
self.eps = args.label_smoothing | |||||
self.ignore_prefix_size = args.ignore_prefix_size | |||||
self.ignore_eos = args.ignore_eos | |||||
self.report_accuracy = args.report_accuracy | |||||
self.drop_worst_ratio = args.drop_worst_ratio | |||||
self.drop_worst_after = args.drop_worst_after | |||||
self.use_rdrop = args.use_rdrop | |||||
self.reg_alpha = args.reg_alpha | |||||
self.sample_patch_num = args.sample_patch_num | |||||
self.constraint_start = None | |||||
self.constraint_end = None | |||||
if args.constraint_range is not None: | |||||
constraint_start, constraint_end = args.constraint_range.split(',') | |||||
self.constraint_start = int(constraint_start) | |||||
self.constraint_end = int(constraint_end) | |||||
self.padding_idx = args.tokenizer.pad_token_id | |||||
self.args = args | |||||
def forward(self, output, sample, update_num=0, reduce=True): | |||||
"""Compute the loss for the given sample. | |||||
Returns a tuple with three elements: | |||||
1) the loss | |||||
2) the sample size, which is used as the denominator for the gradient | |||||
3) logging outputs to display while training | |||||
""" | |||||
if isinstance(sample, list): | |||||
if self.sample_patch_num > 0: | |||||
sample[0]['net_input'][ | |||||
'sample_patch_num'] = self.sample_patch_num | |||||
loss_v1, sample_size_v1, logging_output_v1 = self.forward( | |||||
output[0], sample[0], update_num, reduce) | |||||
loss_v2, sample_size_v2, logging_output_v2 = self.forward( | |||||
output[1], sample[1], update_num, reduce) | |||||
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2 | |||||
sample_size = 1 | |||||
logging_output = { | |||||
'loss': | |||||
loss.data, | |||||
'loss_v1': | |||||
loss_v1.data, | |||||
'loss_v2': | |||||
loss_v2.data, | |||||
'nll_loss': | |||||
logging_output_v1['nll_loss'].data / sample_size_v1 | |||||
+ logging_output_v2['nll_loss'].data / sample_size_v2, | |||||
'ntokens': | |||||
logging_output_v1['ntokens'] + logging_output_v2['ntokens'], | |||||
'nsentences': | |||||
logging_output_v1['nsentences'] | |||||
+ logging_output_v2['nsentences'], | |||||
'sample_size': | |||||
1, | |||||
'sample_size_v1': | |||||
sample_size_v1, | |||||
'sample_size_v2': | |||||
sample_size_v2, | |||||
} | |||||
return loss, sample_size, logging_output | |||||
if self.use_rdrop: | |||||
construct_rdrop_sample(sample) | |||||
net_output = output | |||||
# model(**sample["net_input"]) | |||||
loss, nll_loss, ntokens = self.compute_loss( | |||||
net_output, sample, update_num, reduce=reduce) | |||||
sample_size = ( | |||||
sample['target'].size(0) if self.sentence_avg else ntokens) | |||||
logging_output = { | |||||
'loss': loss.data, | |||||
'nll_loss': nll_loss.data, | |||||
'ntokens': sample['ntokens'], | |||||
'nsentences': sample['nsentences'], | |||||
'sample_size': sample_size, | |||||
} | |||||
return loss, sample_size, logging_output | |||||
def get_lprobs_and_target(self, net_output, sample): | |||||
conf = sample['conf'][:, None, None] if 'conf' in sample and sample[ | |||||
'conf'] is not None else 1 | |||||
constraint_masks = None | |||||
if 'constraint_masks' in sample and sample[ | |||||
'constraint_masks'] is not None: | |||||
constraint_masks = sample['constraint_masks'] | |||||
net_output[0].masked_fill_(~constraint_masks, -math.inf) | |||||
if self.constraint_start is not None and self.constraint_end is not None: | |||||
net_output[0][:, :, 4:self.constraint_start] = -math.inf | |||||
net_output[0][:, :, self.constraint_end:] = -math.inf | |||||
lprobs = F.log_softmax( | |||||
net_output[0], dim=-1, dtype=torch.float32) * conf | |||||
target = sample['target'] | |||||
if self.ignore_prefix_size > 0: | |||||
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() | |||||
target = target[:, self.ignore_prefix_size:].contiguous() | |||||
if constraint_masks is not None: | |||||
constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable | |||||
if self.ignore_eos: | |||||
bsz, seq_len, embed_dim = lprobs.size() | |||||
eos_indices = target.eq(self.task.tgt_dict.eos()) | |||||
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim) | |||||
target = target[~eos_indices].reshape(bsz, seq_len - 1) | |||||
if constraint_masks is not None: | |||||
constraint_masks = constraint_masks[~eos_indices].reshape( | |||||
bsz, seq_len - 1, embed_dim) | |||||
if constraint_masks is not None: | |||||
constraint_masks = constraint_masks.view(-1, | |||||
constraint_masks.size(-1)) | |||||
return lprobs.view(-1, | |||||
lprobs.size(-1)), target.view(-1), constraint_masks | |||||
def compute_loss(self, net_output, sample, update_num, reduce=True): | |||||
lprobs, target, constraint_masks = self.get_lprobs_and_target( | |||||
net_output, sample) | |||||
if constraint_masks is not None: | |||||
constraint_masks = constraint_masks[target != self.padding_idx] | |||||
lprobs = lprobs[target != self.padding_idx] | |||||
target = target[target != self.padding_idx] | |||||
loss, nll_loss, ntokens = label_smoothed_nll_loss( | |||||
lprobs, | |||||
target, | |||||
self.eps, | |||||
update_num, | |||||
reduce=reduce, | |||||
drop_worst_ratio=self.drop_worst_ratio, | |||||
drop_worst_after=self.drop_worst_after, | |||||
use_rdrop=self.use_rdrop, | |||||
reg_alpha=self.reg_alpha, | |||||
constraint_masks=constraint_masks, | |||||
constraint_start=self.constraint_start, | |||||
constraint_end=self.constraint_end) | |||||
return loss, nll_loss, ntokens | |||||
def get_schedule(args): | |||||
if args.schedule == 'const': | |||||
scheduler_class = transformers.get_constant_schedule_with_warmup | |||||
scheduler_args = { | |||||
'num_warmup_steps': | |||||
int(args.warmup_proportion * args.num_train_steps) | |||||
} | |||||
elif args.schedule == 'linear': | |||||
scheduler_class = transformers.get_linear_schedule_with_warmup | |||||
scheduler_args = { | |||||
'num_warmup_steps': | |||||
int(args.warmup_proportion * args.num_train_steps), | |||||
'num_training_steps': args.num_train_steps | |||||
} | |||||
elif args.schedule == 'cosine': | |||||
scheduler_class = transformers.get_cosine_schedule_with_warmup | |||||
scheduler_args = { | |||||
'num_warmup_steps': | |||||
int(args.warmup_proportion * args.num_train_steps), | |||||
'num_training_steps': args.num_train_steps | |||||
} | |||||
elif args.schedule == 'polynomial_decay': | |||||
scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup | |||||
scheduler_args = { | |||||
'num_warmup_steps': | |||||
int(args.warmup_proportion * args.num_train_steps), | |||||
'num_training_steps': args.num_train_steps, | |||||
'lr_end': args.lr_end | |||||
} | |||||
else: | |||||
raise NotImplementedError | |||||
return scheduler_class, scheduler_args |
@@ -0,0 +1,14 @@ | |||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
from .fp16 import FP16_Module, FP16_Optimizer |
@@ -0,0 +1,655 @@ | |||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
"""Stable version of apex FP16 Optimizer""" | |||||
import torch | |||||
from torch import nn | |||||
from torch.autograd import Variable | |||||
from torch.nn.parameter import Parameter | |||||
from .fp16util import (master_params_to_model_params, | |||||
model_grads_to_master_grads) | |||||
from .loss_scaler import DynamicLossScaler, LossScaler | |||||
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) | |||||
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) | |||||
def conversion_helper(val, conversion): | |||||
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" | |||||
if not isinstance(val, (tuple, list)): | |||||
return conversion(val) | |||||
rtn = [conversion_helper(v, conversion) for v in val] | |||||
if isinstance(val, tuple): | |||||
rtn = tuple(rtn) | |||||
return rtn | |||||
def fp32_to_fp16(val): | |||||
"""Convert fp32 `val` to fp16""" | |||||
def half_conversion(val): | |||||
val_typecheck = val | |||||
if isinstance(val_typecheck, (Parameter, Variable)): | |||||
val_typecheck = val.data | |||||
if isinstance(val_typecheck, FLOAT_TYPES): | |||||
val = val.half() | |||||
return val | |||||
return conversion_helper(val, half_conversion) | |||||
def fp16_to_fp32(val): | |||||
"""Convert fp16 `val` to fp32""" | |||||
def float_conversion(val): | |||||
val_typecheck = val | |||||
if isinstance(val_typecheck, (Parameter, Variable)): | |||||
val_typecheck = val.data | |||||
if isinstance(val_typecheck, HALF_TYPES): | |||||
val = val.float() | |||||
return val | |||||
return conversion_helper(val, float_conversion) | |||||
class FP16_Module(nn.Module): | |||||
def __init__(self, module): | |||||
super(FP16_Module, self).__init__() | |||||
self.add_module('module', module.half()) | |||||
def forward(self, *inputs, **kwargs): | |||||
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) | |||||
def state_dict(self, destination=None, prefix='', keep_vars=False): | |||||
return self.module.state_dict(destination, prefix, keep_vars) | |||||
def load_state_dict(self, state_dict, strict=True): | |||||
self.module.load_state_dict(state_dict, strict=strict) | |||||
class FP16_Optimizer(object): | |||||
""" | |||||
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, | |||||
and manage static or dynamic loss scaling and master weights in a manner transparent to the user. | |||||
For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, | |||||
and changing the call to ``backward``. | |||||
Example:: | |||||
model = torch.nn.Linear(D_in, D_out).cuda().half() | |||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||||
# Name the FP16_Optimizer instance to replace the existing optimizer | |||||
# (recommended but not required): | |||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||||
... | |||||
# loss.backward() becomes: | |||||
optimizer.backward(loss) | |||||
... | |||||
Example with dynamic loss scaling:: | |||||
... | |||||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) | |||||
# optional arg to control dynamic loss scaling behavior | |||||
# dynamic_loss_args={'scale_window' : 500}) | |||||
# Usually, dynamic_loss_args is not necessary. | |||||
Args: | |||||
init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa | |||||
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa | |||||
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa | |||||
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa | |||||
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa | |||||
``init_optimizer`` is expected to have been constructed in the ordinary way. | |||||
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be | |||||
named to replace ``init_optimizer``, for two reasons: | |||||
First, it means that references to the same name | |||||
later in the file will not have to change. | |||||
Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to | |||||
modify ``init_optimizer``. If you do choose a unique name for the new | |||||
:class:`FP16_Optimizer` instance, you should only work with this new instance, | |||||
because the preexisting optimizer might no longer behave as expected. | |||||
``init_optimizer`` may be any Pytorch optimizer. | |||||
It may contain a mixture of fp16 and fp32 parameters organized into any number of | |||||
``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will | |||||
ingest these ``param_groups`` and remember them. | |||||
Calls to :: | |||||
loss.backward() | |||||
must be replaced with :: | |||||
optimizer.backward(loss) | |||||
because :class:`FP16_Optimizer` requires ownership of the backward pass to implement | |||||
loss scaling and copies to master gradients. | |||||
.. note:: | |||||
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients | |||||
are downscaled before being applied. This means that adjusting the loss scale, or using | |||||
dynamic loss scaling, should not require retuning the learning rate or any other | |||||
hyperparameters. | |||||
**Advanced options** | |||||
**Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. | |||||
See docstring for :attr:`step`. | |||||
**Gradient clipping**: Use :attr:`clip_master_grads`. | |||||
**Multiple losses**: If your model accumulates gradients from multiple losses, | |||||
this can be made more efficient by supplying ``update_master_grads=False`` | |||||
to :attr:`backward`. See docstring for :attr:`backward`. | |||||
**Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: | |||||
print(optimizer.loss_scale) | |||||
optimizer.loss_scale = new_loss_scale | |||||
For static loss scaling, manually adjusting the loss scale over time is a reasonable | |||||
thing to do. During later epochs, gradients may become smaller, and a | |||||
higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss | |||||
scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting | |||||
the loss scale is not recommended. | |||||
**Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in | |||||
Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` | |||||
should still work as intended. | |||||
""" | |||||
def __init__(self, | |||||
init_optimizer, | |||||
static_loss_scale=1.0, | |||||
dynamic_loss_scale=False, | |||||
dynamic_loss_args=None, | |||||
verbose=False): | |||||
if not torch.cuda.is_available: | |||||
raise SystemError('Cannot use fp16 without CUDA.') | |||||
self.verbose = verbose | |||||
self.optimizer = init_optimizer | |||||
# init_state_dict sets up an alternative way to cast per-param state tensors. | |||||
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. | |||||
# init_state_dict = init_optimizer.state_dict() | |||||
self.fp16_groups = [] | |||||
self.fp32_from_fp16_groups = [] | |||||
self.fp32_from_fp32_groups = [] | |||||
for i, param_group in enumerate(self.optimizer.param_groups): | |||||
self.maybe_print( | |||||
'FP16_Optimizer processing param group {}:'.format(i)) | |||||
fp16_params_this_group = [] | |||||
fp32_params_this_group = [] | |||||
fp32_from_fp16_params_this_group = [] | |||||
for i, param in enumerate(param_group['params']): | |||||
if param.requires_grad: | |||||
if param.type() == 'torch.cuda.HalfTensor': | |||||
self.maybe_print( | |||||
'FP16_Optimizer received torch.cuda.HalfTensor with {}' | |||||
.format(param.size())) | |||||
fp16_params_this_group.append(param) | |||||
master_param = param.detach().clone().float() | |||||
master_param.requires_grad = True | |||||
# Copythe model parallel flag. | |||||
master_param.model_parallel = param.model_parallel | |||||
param_group['params'][i] = master_param | |||||
fp32_from_fp16_params_this_group.append(master_param) | |||||
# Reset existing state dict key to the new master param. | |||||
# We still need to recast per-param state tensors, if any, to FP32. | |||||
if param in self.optimizer.state: | |||||
self.optimizer.state[ | |||||
master_param] = self.optimizer.state.pop(param) | |||||
elif param.type() == 'torch.cuda.FloatTensor': | |||||
self.maybe_print( | |||||
'FP16_Optimizer received torch.cuda.FloatTensor with {}' | |||||
.format(param.size())) | |||||
fp32_params_this_group.append(param) | |||||
param_group['params'][i] = param | |||||
else: | |||||
raise TypeError( | |||||
'Wrapped parameters must be either ' | |||||
'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' | |||||
'Received {}'.format(param.type())) | |||||
self.fp16_groups.append(fp16_params_this_group) | |||||
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) | |||||
self.fp32_from_fp32_groups.append(fp32_params_this_group) | |||||
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors | |||||
self.optimizer.load_state_dict(self.optimizer.state_dict()) | |||||
# alternative way to cast per-param state tensors: | |||||
# self.optimizer.load_state_dict(init_state_dict) | |||||
if dynamic_loss_scale: | |||||
self.dynamic_loss_scale = True | |||||
if dynamic_loss_args is not None: | |||||
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) | |||||
else: | |||||
self.loss_scaler = DynamicLossScaler() | |||||
else: | |||||
self.dynamic_loss_scale = False | |||||
self.loss_scaler = LossScaler(static_loss_scale) | |||||
self.overflow = False | |||||
self.first_closure_call_this_step = True | |||||
self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_ | |||||
def maybe_print(self, msg): | |||||
if self.verbose: | |||||
print(msg) | |||||
def __getstate__(self): | |||||
raise RuntimeError( | |||||
'FP16_Optimizer should be serialized using state_dict().') | |||||
def __setstate__(self, state): | |||||
raise RuntimeError( | |||||
'FP16_Optimizer should be deserialized using load_state_dict().') | |||||
def zero_grad(self, set_grads_to_None=False): | |||||
""" | |||||
Zero fp32 and fp16 parameter grads. | |||||
""" | |||||
# In principle, only the .grad attributes of the model params need to be zeroed, | |||||
# because gradients are copied into the FP32 master params. However, we zero | |||||
# all gradients owned by the optimizer, just to be safe: | |||||
for group in self.optimizer.param_groups: | |||||
for p in group['params']: | |||||
if set_grads_to_None: | |||||
p.grad = None | |||||
else: | |||||
if p.grad is not None: | |||||
p.grad.detach_() | |||||
p.grad.zero_() | |||||
# Zero fp16 gradients owned by the model: | |||||
for fp16_group in self.fp16_groups: | |||||
for param in fp16_group: | |||||
if set_grads_to_None: | |||||
param.grad = None | |||||
else: | |||||
if param.grad is not None: | |||||
param.grad.detach_( | |||||
) # as in torch.optim.optimizer.zero_grad() | |||||
param.grad.zero_() | |||||
def _check_overflow(self): | |||||
params = [] | |||||
for group in self.fp16_groups: | |||||
for param in group: | |||||
params.append(param) | |||||
for group in self.fp32_from_fp32_groups: | |||||
for param in group: | |||||
params.append(param) | |||||
self.overflow = self.loss_scaler.has_overflow(params) | |||||
def _update_scale(self, has_overflow=False): | |||||
self.loss_scaler.update_scale(has_overflow) | |||||
def _master_params_to_model_params(self): | |||||
for fp16_group, fp32_from_fp16_group in zip( | |||||
self.fp16_groups, self.fp32_from_fp16_groups): | |||||
master_params_to_model_params(fp16_group, fp32_from_fp16_group) | |||||
def _model_params_to_master_params(self): | |||||
for fp16_group, fp32_from_fp16_group in zip( | |||||
self.fp16_groups, self.fp32_from_fp16_groups): | |||||
master_params_to_model_params(fp32_from_fp16_group, fp16_group) | |||||
# To consider: Integrate distributed with this wrapper by registering a hook on each variable | |||||
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. | |||||
def _model_grads_to_master_grads(self): | |||||
for fp16_group, fp32_from_fp16_group in zip( | |||||
self.fp16_groups, self.fp32_from_fp16_groups): | |||||
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) | |||||
def _downscale_master(self): | |||||
if self.loss_scale != 1.0: | |||||
for group in self.optimizer.param_groups: | |||||
for param in group['params']: | |||||
if param.grad is not None: | |||||
param.grad.data.mul_(1. / self.loss_scale) | |||||
def clip_master_grads(self, max_norm, norm_type=2): | |||||
""" | |||||
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. | |||||
Args: | |||||
max_norm (float or int): max norm of the gradients | |||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||||
infinity norm. | |||||
Returns: | |||||
Total norm of the current fp32 gradients (viewed as a single vector). | |||||
.. warning:: | |||||
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa | |||||
""" | |||||
if not self.overflow: | |||||
fp32_params = [] | |||||
for param_group in self.optimizer.param_groups: | |||||
for param in param_group['params']: | |||||
fp32_params.append(param) | |||||
return self.clip_grad_norm(fp32_params, max_norm, norm_type) | |||||
else: | |||||
return -1 | |||||
def state_dict(self): | |||||
""" | |||||
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. | |||||
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict | |||||
of the contained Pytorch optimizer. | |||||
Example:: | |||||
checkpoint = {} | |||||
checkpoint['model'] = model.state_dict() | |||||
checkpoint['optimizer'] = optimizer.state_dict() | |||||
torch.save(checkpoint, "saved.pth") | |||||
""" | |||||
state_dict = {} | |||||
state_dict['loss_scaler'] = self.loss_scaler | |||||
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale | |||||
state_dict['overflow'] = self.overflow | |||||
state_dict[ | |||||
'first_closure_call_this_step'] = self.first_closure_call_this_step | |||||
state_dict['optimizer_state_dict'] = self.optimizer.state_dict() | |||||
state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups | |||||
return state_dict | |||||
def load_state_dict(self, state_dict): | |||||
""" | |||||
Loads a state_dict created by an earlier call to state_dict(). | |||||
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, | |||||
whose parameters in turn came from ``model``, it is expected that the user | |||||
will call ``model.load_state_dict()`` before | |||||
``fp16_optimizer_instance.load_state_dict()`` is called. | |||||
Example:: | |||||
model = torch.nn.Linear(D_in, D_out).cuda().half() | |||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||||
... | |||||
checkpoint = torch.load("saved.pth") | |||||
model.load_state_dict(checkpoint['model']) | |||||
optimizer.load_state_dict(checkpoint['optimizer']) | |||||
""" | |||||
# I think it should actually be ok to reload the optimizer before the model. | |||||
self.loss_scaler = state_dict['loss_scaler'] | |||||
self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] | |||||
self.overflow = state_dict['overflow'] | |||||
self.first_closure_call_this_step = state_dict[ | |||||
'first_closure_call_this_step'] | |||||
self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) | |||||
# At this point, the optimizer's references to the model's fp32 parameters are up to date. | |||||
# The optimizer's hyperparameters and internal buffers are also up to date. | |||||
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still | |||||
# out of date. There are two options. | |||||
# 1: Refresh the master params from the model's fp16 params. | |||||
# This requires less storage but incurs precision loss. | |||||
# 2: Save and restore the fp32 master copies separately. | |||||
# We choose option 2. | |||||
# | |||||
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device | |||||
# of their associated parameters, because it's possible those buffers might not exist yet in | |||||
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been | |||||
# constructed in the same way as the one whose state_dict we are loading, the same master params | |||||
# are guaranteed to exist, so we can just copy_() from the saved master params. | |||||
for current_group, saved_group in zip(self.fp32_from_fp16_groups, | |||||
state_dict['fp32_from_fp16']): | |||||
for current, saved in zip(current_group, saved_group): | |||||
current.data.copy_(saved.data) | |||||
def step(self, closure=None): # could add clip option. | |||||
""" | |||||
If no closure is supplied, :attr:`step` should be called after | |||||
``fp16_optimizer_obj.backward(loss)``. | |||||
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to | |||||
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params | |||||
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run | |||||
another forward pass using their model. | |||||
If a closure is supplied, :attr:`step` may be called without a prior call to | |||||
:attr:`backward(loss)`. | |||||
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. | |||||
However, the user should take care that any ``loss.backward()`` call within the closure | |||||
has been replaced by ``fp16_optimizer_obj.backward(loss)``. | |||||
Args: | |||||
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa | |||||
Example with closure:: | |||||
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an | |||||
# existing pytorch optimizer. | |||||
for input, target in dataset: | |||||
def closure(): | |||||
optimizer.zero_grad() | |||||
output = model(input) | |||||
loss = loss_fn(output, target) | |||||
# loss.backward() becomes: | |||||
optimizer.backward(loss) | |||||
return loss | |||||
optimizer.step(closure) | |||||
.. warning:: | |||||
Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. | |||||
.. _`ordinary Pytorch optimizer use`: | |||||
http://pytorch.org/docs/master/optim.html#optimizer-step-closure | |||||
""" | |||||
scale = self.loss_scaler.loss_scale | |||||
self._update_scale(self.overflow) | |||||
if self.overflow: | |||||
self.maybe_print( | |||||
'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' | |||||
.format(scale, self.loss_scale)) | |||||
return | |||||
if closure is not None: | |||||
retval = self._step_with_closure(closure) | |||||
else: | |||||
retval = self.optimizer.step() | |||||
self._master_params_to_model_params() | |||||
return retval | |||||
def _step_with_closure(self, closure): | |||||
def wrapped_closure(): | |||||
# helpful for debugging | |||||
# print("Calling wrapped_closure, first_closure_call_this_step = {}" | |||||
# .format(self.first_closure_call_this_step)) | |||||
if self.first_closure_call_this_step: | |||||
# We expect that the fp16 params are initially fresh on entering self.step(), | |||||
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure() | |||||
# is called within self.optimizer.step(). | |||||
self.first_closure_call_this_step = False | |||||
else: | |||||
# If self.optimizer.step() internally calls wrapped_closure more than once, | |||||
# it may update the fp32 params after each call. However, self.optimizer | |||||
# doesn't know about the fp16 params at all. If the fp32 params get updated, | |||||
# we can't rely on self.optimizer to refresh the fp16 params. We need | |||||
# to handle that manually: | |||||
self._master_params_to_model_params() | |||||
# Our API expects the user to give us ownership of the backward() call by | |||||
# replacing all calls to loss.backward() with optimizer.backward(loss). | |||||
# This requirement holds whether or not the call to backward() is made within a closure. | |||||
# If the user is properly calling optimizer.backward(loss) within "closure," | |||||
# calling closure() here will give the fp32 master params fresh gradients | |||||
# for the optimizer to play with, so all wrapped_closure needs to do is call | |||||
# closure() and return the loss. | |||||
temp_loss = closure() | |||||
while (self.overflow): | |||||
scale = self.loss_scaler.loss_scale | |||||
self._update_scale(self.overflow) | |||||
self.maybe_print( | |||||
'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' | |||||
'reducing to {}'.format(scale, self.loss_scale)) | |||||
temp_loss = closure() | |||||
return temp_loss | |||||
retval = self.optimizer.step(wrapped_closure) | |||||
self.first_closure_call_this_step = True | |||||
return retval | |||||
def backward(self, loss, update_master_grads=True, retain_graph=False): | |||||
""" | |||||
:attr:`backward` performs the following conceptual steps: | |||||
1. fp32_loss = loss.float() (see first Note below) | |||||
2. scaled_loss = fp32_loss*loss_scale | |||||
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa | |||||
4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa | |||||
5. Finally, master grads are divided by loss_scale. | |||||
In this way, after :attr:`backward`, the master params have fresh gradients, | |||||
and :attr:`step` may be called. | |||||
.. note:: | |||||
:attr:`backward` internally converts the loss to fp32 before applying the loss scale. | |||||
This provides some additional safety against overflow if the user has supplied an | |||||
fp16 loss value. | |||||
However, for maximum overflow safety, the user should | |||||
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to | |||||
:attr:`backward`. | |||||
.. warning:: | |||||
The gradients found in a model's leaves after the call to | |||||
:attr:`backward` should not be regarded as valid in general, | |||||
because it's possible | |||||
they have been scaled (and in the case of dynamic loss scaling, | |||||
the scale factor may change over time). | |||||
If the user wants to inspect gradients after a call to :attr:`backward`, | |||||
only the master gradients should be regarded as valid. These can be retrieved via | |||||
:attr:`inspect_master_grad_data()`. | |||||
Args: | |||||
loss: The loss output by the user's model. loss may be either float or half (but see first Note above). | |||||
update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa | |||||
retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa | |||||
Example:: | |||||
# Ordinary operation: | |||||
optimizer.backward(loss) | |||||
# Naive operation with multiple losses (technically valid, but less efficient): | |||||
# fp32 grads will be correct after the second call, but | |||||
# the first call incurs an unnecessary fp16->fp32 grad copy. | |||||
optimizer.backward(loss1) | |||||
optimizer.backward(loss2) | |||||
# More efficient way to handle multiple losses: | |||||
# The fp16->fp32 grad copy is delayed until fp16 grads from all | |||||
# losses have been accumulated. | |||||
optimizer.backward(loss1, update_master_grads=False) | |||||
optimizer.backward(loss2, update_master_grads=False) | |||||
optimizer.update_master_grads() | |||||
""" | |||||
# To consider: try multiple backward passes using retain_grad=True to find | |||||
# a loss scale that works. After you find a loss scale that works, do a final dummy | |||||
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid | |||||
# discarding the iteration, but probably wouldn't improve overall efficiency. | |||||
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) | |||||
if update_master_grads: | |||||
self.update_master_grads() | |||||
def update_master_grads(self): | |||||
""" | |||||
Copy the ``.grad`` attribute from stored references to fp16 parameters to | |||||
the ``.grad`` attribute of the fp32 master parameters that are directly | |||||
updated by the optimizer. :attr:`update_master_grads` only needs to be called if | |||||
``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. | |||||
""" | |||||
if self.dynamic_loss_scale: | |||||
self._check_overflow() | |||||
if self.overflow: return # noqa | |||||
self._model_grads_to_master_grads() | |||||
self._downscale_master() | |||||
def inspect_master_grad_data(self): | |||||
""" | |||||
When running with :class:`FP16_Optimizer`, | |||||
``.grad`` attributes of a model's fp16 leaves should not be | |||||
regarded as truthful, because they might be scaled. | |||||
After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, | |||||
the fp32 master params' ``.grad`` | |||||
attributes will contain valid gradients properly divided by the loss scale. However, | |||||
because :class:`FP16_Optimizer` flattens some parameters, accessing them may be | |||||
nonintuitive. :attr:`inspect_master_grad_data` | |||||
allows those gradients to be viewed with shapes corresponding to their associated model leaves. | |||||
Returns: | |||||
List of lists (one list for each parameter group). The list for each parameter group | |||||
is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. | |||||
""" | |||||
if self.overflow: | |||||
print( | |||||
'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' | |||||
'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' | |||||
) | |||||
return None | |||||
else: | |||||
# The optimizer owns only references to master params. | |||||
master_grads_data = [] | |||||
for param_group in self.optimizer.param_groups: | |||||
master_grads_this_group = [] | |||||
for param in param_group['params']: | |||||
if param.grad is not None: | |||||
master_grads_this_group.append(param.grad.data) | |||||
else: | |||||
master_grads_this_group.append(None) | |||||
master_grads_data.append(master_grads_this_group) | |||||
return master_grads_data | |||||
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" | |||||
def _get_loss_scale(self): | |||||
return self.loss_scaler.loss_scale | |||||
def _set_loss_scale(self, value): | |||||
self.loss_scaler.cur_scale = value | |||||
loss_scale = property(_get_loss_scale, _set_loss_scale) | |||||
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" | |||||
def _get_state(self): | |||||
return self.optimizer.state | |||||
def _set_state(self, value): | |||||
self.optimizer.state = value | |||||
state = property(_get_state, _set_state) | |||||
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" | |||||
# (for example, to adjust the learning rate) | |||||
def _get_param_groups(self): | |||||
return self.optimizer.param_groups | |||||
def _set_param_groups(self, value): | |||||
self.optimizer.param_groups = value | |||||
param_groups = property(_get_param_groups, _set_param_groups) |
@@ -0,0 +1,216 @@ | |||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||||
from torch.autograd import Variable | |||||
class tofp16(nn.Module): | |||||
""" | |||||
Utility module that implements:: | |||||
def forward(self, input): | |||||
return input.half() | |||||
""" | |||||
def __init__(self): | |||||
super(tofp16, self).__init__() | |||||
def forward(self, input): | |||||
return input.half() | |||||
def BN_convert_float(module): | |||||
""" | |||||
Utility function for network_to_half(). | |||||
Retained for legacy purposes. | |||||
""" | |||||
if isinstance( | |||||
module, | |||||
torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: | |||||
module.float() | |||||
for child in module.children(): | |||||
BN_convert_float(child) | |||||
return module | |||||
def network_to_half(network): | |||||
""" | |||||
Convert model to half precision in a batchnorm-safe way. | |||||
Retained for legacy purposes. It is recommended to use FP16Model. | |||||
""" | |||||
return nn.Sequential(tofp16(), BN_convert_float(network.half())) | |||||
def convert_module(module, dtype): | |||||
""" | |||||
Converts a module's immediate parameters and buffers to dtype. | |||||
""" | |||||
for param in module.parameters(recurse=False): | |||||
if param is not None: | |||||
if param.data.dtype.is_floating_point: | |||||
param.data = param.data.to(dtype=dtype) | |||||
if param._grad is not None and param._grad.data.dtype.is_floating_point: | |||||
param._grad.data = param._grad.data.to(dtype=dtype) | |||||
for buf in module.buffers(recurse=False): | |||||
if buf is not None and buf.data.dtype.is_floating_point: | |||||
buf.data = buf.data.to(dtype=dtype) | |||||
def convert_network(network, dtype): | |||||
""" | |||||
Converts a network's parameters and buffers to dtype. | |||||
""" | |||||
for module in network.modules(): | |||||
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm | |||||
) and module.affine is True: | |||||
continue | |||||
convert_module(module, dtype) | |||||
return network | |||||
class FP16Model(nn.Module): | |||||
""" | |||||
Convert model to half precision in a batchnorm-safe way. | |||||
""" | |||||
def __init__(self, network): | |||||
super(FP16Model, self).__init__() | |||||
self.network = convert_network(network, dtype=torch.half) | |||||
def forward(self, *inputs): | |||||
inputs = tuple(t.half() for t in inputs) | |||||
return self.network(*inputs) | |||||
def backwards_debug_hook(grad): | |||||
raise RuntimeError( | |||||
'master_params recieved a gradient in the backward pass!') | |||||
def prep_param_lists(model, flat_master=False): | |||||
""" | |||||
Creates a list of FP32 master parameters for a given model, as in | |||||
`Training Neural Networks with Mixed Precision: Real Examples`_. | |||||
Args: | |||||
model (torch.nn.Module): Existing Pytorch model | |||||
flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa | |||||
Returns: | |||||
A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa | |||||
Example:: | |||||
model_params, master_params = prep_param_lists(model) | |||||
.. warning:: | |||||
Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa | |||||
.. _`Training Neural Networks with Mixed Precision: Real Examples`: | |||||
http://on-demand.gputechconf.com/gtc/2018/video/S81012/ | |||||
""" | |||||
model_params = [ | |||||
param for param in model.parameters() if param.requires_grad | |||||
] | |||||
if flat_master: | |||||
# Give the user some more useful error messages | |||||
try: | |||||
# flatten_dense_tensors returns a contiguous flat array. | |||||
# http://pytorch.org/docs/master/_modules/torch/_utils.html | |||||
master_params = _flatten_dense_tensors( | |||||
[param.data for param in model_params]).float() | |||||
except: # noqa | |||||
print( | |||||
'Error in prep_param_lists: model may contain a mixture of parameters ' | |||||
'of different types. Use flat_master=False, or use F16_Optimizer.' | |||||
) | |||||
raise | |||||
master_params = torch.nn.Parameter(master_params) | |||||
master_params.requires_grad = True | |||||
# master_params.register_hook(backwards_debug_hook) | |||||
if master_params.grad is None: | |||||
master_params.grad = master_params.new(*master_params.size()) | |||||
return model_params, [master_params] | |||||
else: | |||||
master_params = [ | |||||
param.clone().float().detach() for param in model_params | |||||
] | |||||
for param in master_params: | |||||
param.requires_grad = True | |||||
return model_params, master_params | |||||
def model_grads_to_master_grads(model_params, | |||||
master_params, | |||||
flat_master=False): | |||||
""" | |||||
Copy model gradients to master gradients. | |||||
Args: | |||||
model_params: List of model parameters created by :func:`prep_param_lists`. | |||||
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa | |||||
""" | |||||
if flat_master: | |||||
# The flattening may incur one more deep copy than is necessary. | |||||
master_params[0].grad.data.copy_( | |||||
_flatten_dense_tensors([p.grad.data for p in model_params])) | |||||
else: | |||||
for model, master in zip(model_params, master_params): | |||||
if model.grad is not None: | |||||
if master.grad is None: | |||||
master.grad = Variable( | |||||
master.data.new(*master.data.size())) | |||||
master.grad.data.copy_(model.grad.data) | |||||
else: | |||||
master.grad = None | |||||
def master_params_to_model_params(model_params, | |||||
master_params, | |||||
flat_master=False): | |||||
""" | |||||
Copy master parameters to model parameters. | |||||
Args: | |||||
model_params: List of model parameters created by :func:`prep_param_lists`. | |||||
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa | |||||
""" | |||||
if flat_master: | |||||
for model, master in zip( | |||||
model_params, | |||||
_unflatten_dense_tensors(master_params[0].data, model_params)): | |||||
model.data.copy_(master) | |||||
else: | |||||
for model, master in zip(model_params, master_params): | |||||
model.data.copy_(master.data) | |||||
# Backward compatibility fixes | |||||
def to_python_float(t): | |||||
if hasattr(t, 'item'): | |||||
return t.item() | |||||
else: | |||||
return t[0] | |||||
TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||||
TORCH_MINOR = int(torch.__version__.split('.')[1]) |
@@ -0,0 +1,237 @@ | |||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import torch | |||||
# item() is a recent addition, so this helps with backward compatibility. | |||||
def to_python_float(t): | |||||
if hasattr(t, 'item'): | |||||
return t.item() | |||||
else: | |||||
return t[0] | |||||
class LossScaler: | |||||
""" | |||||
Class that manages a static loss scale. This class is intended to interact with | |||||
:class:`FP16_Optimizer`, and should not be directly manipulated by the user. | |||||
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to | |||||
:class:`FP16_Optimizer`'s constructor. | |||||
Args: | |||||
scale (float, optional, default=1.0): The loss scale. | |||||
""" | |||||
def __init__(self, scale=1): | |||||
self.cur_scale = scale | |||||
# `params` is a list / generator of torch.Variable | |||||
def has_overflow(self, params): | |||||
return False | |||||
# `x` is a torch.Tensor | |||||
def _has_inf_or_nan(x): | |||||
return False | |||||
def update_scale(self, overflow): | |||||
pass | |||||
@property | |||||
def loss_scale(self): | |||||
return self.cur_scale | |||||
def scale_gradient(self, module, grad_in, grad_out): | |||||
return tuple(self.loss_scale * g for g in grad_in) | |||||
def backward(self, loss, retain_graph=False): | |||||
scaled_loss = loss * self.loss_scale | |||||
scaled_loss.backward(retain_graph=retain_graph) | |||||
class DynamicLossScaler: | |||||
""" | |||||
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` | |||||
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of | |||||
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` | |||||
operates, because the default options can be changed using the | |||||
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. | |||||
Loss scaling is designed to combat the problem of underflowing gradients encountered at long | |||||
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss | |||||
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are | |||||
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has | |||||
occurred. | |||||
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, | |||||
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. | |||||
If a certain number of iterations occur without overflowing gradients detected, | |||||
:class:`DynamicLossScaler` increases the loss scale once more. | |||||
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of | |||||
always using the highest loss scale possible without incurring overflow. | |||||
Args: | |||||
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` | |||||
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa | |||||
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. # noqa | |||||
""" | |||||
def __init__(self, | |||||
init_scale=2**32, | |||||
scale_factor=2., | |||||
scale_window=1000, | |||||
min_scale=1, | |||||
delayed_shift=1, | |||||
consecutive_hysteresis=False): | |||||
self.cur_scale = init_scale | |||||
self.cur_iter = 0 | |||||
self.last_overflow_iter = -1 | |||||
self.scale_factor = scale_factor | |||||
self.scale_window = scale_window | |||||
self.min_scale = min_scale | |||||
self.delayed_shift = delayed_shift | |||||
self.cur_hysteresis = delayed_shift | |||||
self.consecutive_hysteresis = consecutive_hysteresis | |||||
# `params` is a list / generator of torch.Variable | |||||
def has_overflow_serial(self, params): | |||||
for p in params: | |||||
if p.grad is not None and DynamicLossScaler._has_inf_or_nan( | |||||
p.grad.data): | |||||
return True | |||||
return False | |||||
def has_overflow(self, params): | |||||
overflow = self.has_overflow_serial(params) | |||||
overflow_gpu = torch.cuda.ByteTensor([overflow]) | |||||
overflow = overflow_gpu[0].item() | |||||
return bool(overflow) | |||||
# `x` is a torch.Tensor | |||||
def _has_inf_or_nan(x): | |||||
try: | |||||
# if x is half, the .float() incurs an additional deep copy, but it's necessary if | |||||
# Pytorch's .sum() creates a one-element tensor of the same type as x | |||||
# (which is true for some recent version of pytorch). | |||||
cpu_sum = float(x.float().sum()) | |||||
# More efficient version that can be used if .sum() returns a Python scalar | |||||
# cpu_sum = float(x.sum()) | |||||
except RuntimeError as instance: | |||||
# We want to check if inst is actually an overflow exception. | |||||
# RuntimeError could come from a different error. | |||||
# If so, we still want the exception to propagate. | |||||
if 'value cannot be converted' not in instance.args[0]: | |||||
raise | |||||
return True | |||||
else: | |||||
if cpu_sum == float( | |||||
'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: | |||||
return True | |||||
return False | |||||
# `overflow` is boolean indicating whether the gradient overflowed | |||||
def update_scale(self, overflow): | |||||
if not hasattr(self, 'min_scale'): | |||||
self.min_scale = 1 | |||||
if not hasattr(self, 'delayed_shift'): | |||||
self.delayed_shift = 1 | |||||
if not hasattr(self, 'cur_hysteresis'): | |||||
self.cur_hysteresis = 1 | |||||
if not hasattr(self, 'consecutive_hysteresis'): | |||||
self.consecutive_hysteresis = True | |||||
if overflow: | |||||
# self.cur_scale /= self.scale_factor | |||||
if self.delayed_shift == 1 or self.cur_hysteresis == 1: | |||||
self.cur_scale = max(self.cur_scale / self.scale_factor, | |||||
self.min_scale) | |||||
else: | |||||
self.cur_hysteresis -= 1 | |||||
self.last_overflow_iter = self.cur_iter | |||||
else: | |||||
if self.consecutive_hysteresis: | |||||
self.cur_hysteresis = self.delayed_shift | |||||
if (self.cur_iter | |||||
- self.last_overflow_iter) % self.scale_window == 0: | |||||
if not self.consecutive_hysteresis: | |||||
self.cur_hysteresis = self.delayed_shift | |||||
self.cur_scale *= self.scale_factor | |||||
self.cur_iter += 1 | |||||
@property | |||||
def loss_scale(self): | |||||
return self.cur_scale | |||||
def scale_gradient(self, module, grad_in, grad_out): | |||||
return tuple(self.loss_scale * g for g in grad_in) | |||||
def backward(self, loss, retain_graph=False): | |||||
scaled_loss = loss * self.loss_scale | |||||
scaled_loss.backward(retain_graph=retain_graph) | |||||
############################################################## | |||||
# Example usage below here -- assuming it's in a separate file | |||||
############################################################## | |||||
""" | |||||
TO-DO separate out into an example. | |||||
if __name__ == "__main__": | |||||
import torch | |||||
from torch.autograd import Variable | |||||
from dynamic_loss_scaler import DynamicLossScaler | |||||
# N is batch size; D_in is input dimension; | |||||
# H is hidden dimension; D_out is output dimension. | |||||
N, D_in, H, D_out = 64, 1000, 100, 10 | |||||
# Create random Tensors to hold inputs and outputs, and wrap them in Variables. | |||||
x = Variable(torch.randn(N, D_in), requires_grad=False) | |||||
y = Variable(torch.randn(N, D_out), requires_grad=False) | |||||
w1 = Variable(torch.randn(D_in, H), requires_grad=True) | |||||
w2 = Variable(torch.randn(H, D_out), requires_grad=True) | |||||
parameters = [w1, w2] | |||||
learning_rate = 1e-6 | |||||
optimizer = torch.optim.SGD(parameters, lr=learning_rate) | |||||
loss_scaler = DynamicLossScaler() | |||||
for t in range(500): | |||||
y_pred = x.mm(w1).clamp(min=0).mm(w2) | |||||
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale | |||||
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) | |||||
print('Iter {} scaled loss: {}'.format(t, loss.data[0])) | |||||
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) | |||||
# Run backprop | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
# Check for overflow | |||||
has_overflow = DynamicLossScaler.has_overflow(parameters) | |||||
# If no overflow, unscale grad and update as usual | |||||
if not has_overflow: | |||||
for param in parameters: | |||||
param.grad.data.mul_(1. / loss_scaler.loss_scale) | |||||
optimizer.step() | |||||
# Otherwise, don't do anything -- ie, skip iteration | |||||
else: | |||||
print('OVERFLOW!') | |||||
# Update loss scale for next iteration | |||||
loss_scaler.update_scale(has_overflow) | |||||
""" |
@@ -172,6 +172,7 @@ class OfaTasksTest(unittest.TestCase): | |||||
ofa_pipe = pipeline(Tasks.visual_grounding, model=model) | ofa_pipe = pipeline(Tasks.visual_grounding, model=model) | ||||
image = 'data/test/images/visual_grounding.png' | image = 'data/test/images/visual_grounding.png' | ||||
text = '一个圆头的蓝色宝可梦' | text = '一个圆头的蓝色宝可梦' | ||||
text = '火' | |||||
input = {'image': image, 'text': text} | input = {'image': image, 'text': text} | ||||
result = ofa_pipe(input) | result = ofa_pipe(input) | ||||
print(result) | print(result) | ||||
@@ -0,0 +1,20 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import shutil | |||||
import unittest | |||||
from modelscope.trainers.multi_modal.ofa import OFATrainer | |||||
from modelscope.utils.test_utils import test_level | |||||
class TestOfaTrainer(unittest.TestCase): | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_trainer(self): | |||||
model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en' | |||||
self.trainer = OFATrainer(model_id) | |||||
self.trainer.train() | |||||
shutil.rmtree(self.trainer.save_dir) | |||||
if __name__ == '__main__': | |||||
unittest.main() |