| @@ -164,6 +164,7 @@ class Trainers(object): | |||
| # multi-modal trainers | |||
| clip_multi_modal_embedding = 'clip-multi-modal-embedding' | |||
| ofa_tasks = 'ofa-tasks-trainer' | |||
| # cv trainers | |||
| image_instance_segmentation = 'image-instance-segmentation' | |||
| @@ -398,10 +398,27 @@ class SequenceGenerator(nn.Module): | |||
| if self.should_set_src_lengths: | |||
| self.search.set_src_lengths(src_lengths) | |||
| if self.repeat_ngram_blocker is not None and step > prefix_tokens.size( | |||
| 1): | |||
| lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, | |||
| beam_size, step) | |||
| if self.repeat_ngram_blocker is not None: | |||
| # process prefix_tokens | |||
| p_toks_len = prefix_tokens.ne(self.pad).sum( | |||
| dim=1) if prefix_tokens is not None else None | |||
| if p_toks_len is not None: | |||
| p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat( | |||
| 1, beam_size).view(-1) | |||
| no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size | |||
| out_prefix = p_toks_len_beam < ( | |||
| step + no_repeat_ngram_size - 1) | |||
| else: | |||
| out_prefix = torch.ones( | |||
| bsz * beam_size, dtype=torch.bool, device=tokens.device) | |||
| ngram_blocker_tokens = tokens[out_prefix] | |||
| ngram_blocker_lprobs = lprobs[out_prefix] | |||
| ngram_blocker_bsz = out_prefix.sum() // beam_size | |||
| lprobs[out_prefix] = self.repeat_ngram_blocker( | |||
| tokens=ngram_blocker_tokens, | |||
| lprobs=ngram_blocker_lprobs, | |||
| bsz=ngram_blocker_bsz, | |||
| beam_size=beam_size, | |||
| step=step) | |||
| # Shape: (batch, cand_size) | |||
| cand_scores, cand_indices, cand_beams = self.search.step( | |||
| @@ -19,6 +19,7 @@ from dataclasses import dataclass | |||
| from typing import Dict, List, Optional, Tuple | |||
| import torch | |||
| from packaging import version | |||
| from torch import Tensor, nn | |||
| from torch.nn import functional as F | |||
| from transformers.activations import ACT2FN | |||
| @@ -40,6 +41,8 @@ logger = logging.get_logger(__name__) | |||
| _CHECKPOINT_FOR_DOC = 'ofa-base' | |||
| _CONFIG_FOR_DOC = 'OFAConfig' | |||
| _TOKENIZER_FOR_DOC = 'OFATokenizer' | |||
| TORCH_VERSION = version.parse(torch.__version__) | |||
| TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') | |||
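| # torch.meshgrid only accepts the `indexing` keyword on newer PyTorch | |||
| # releases (introduced after the 1.9.x line), so the version constants above | |||
| # gate the call in make_image_bucket_position below. | |||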
| DEFAULT_MAX_SOURCE_POSITIONS = 1024 | |||
| DEFAULT_MAX_TARGET_POSITIONS = 1024 | |||
| @@ -114,8 +117,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance): | |||
| """ | |||
| coords_h = torch.arange(bucket_size) | |||
| coords_w = torch.arange(bucket_size) | |||
| coords = torch.stack(torch.meshgrid([coords_h, coords_w], | |||
| indexing='ij')) # 2, Wh, Ww | |||
| if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION: | |||
| coords = torch.stack( | |||
| torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww | |||
| else: | |||
| coords = torch.stack(torch.meshgrid([coords_h, coords_w])) | |||
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
| relative_coords = coords_flatten[:, :, None] - \ | |||
| coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww | |||
| @@ -11,7 +11,7 @@ from modelscope.metainfo import Preprocessors | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, ModelFile, Tasks | |||
| from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| from .ofa import * # noqa | |||
| @@ -27,11 +27,16 @@ __all__ = [ | |||
| Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) | |||
| class OfaPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| def __init__(self, | |||
| model_dir: str, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| preprocess_mapping = { | |||
| @@ -59,8 +64,8 @@ class OfaPreprocessor(Preprocessor): | |||
| model_dir) | |||
| self.cfg = Config.from_file( | |||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
| self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, | |||
| model_dir) | |||
| self.preprocess = preprocess_mapping[self.cfg.task]( | |||
| cfg=self.cfg, model_dir=model_dir, mode=mode) | |||
| self.keys = input_key_mapping[self.cfg.task] | |||
| self.tokenizer = self.preprocess.tokenizer | |||
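| # Illustrative sketch (reviewer addition, not part of the change): with the | |||
| # new `mode` argument, a training-time preprocessor could be built roughly as | |||
| # follows; the model directory path is a placeholder: | |||
| # | |||
| # from modelscope.preprocessors.multi_modal import OfaPreprocessor | |||
| # from modelscope.utils.constant import ModeKeys | |||
| # preprocessor = OfaPreprocessor(model_dir='/path/to/ofa', mode=ModeKeys.TRAIN) | |||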
| @@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed | |||
| class OfaBasePreprocessor: | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, cfg, model_dir, mode, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| @@ -21,6 +21,7 @@ class OfaBasePreprocessor: | |||
| model_dir (str): model path | |||
| """ | |||
| self.cfg = cfg | |||
| self.mode = mode | |||
| self.language = self.cfg.model.get('language', 'en') | |||
| if self.language == 'en': | |||
| tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
| @@ -12,16 +12,21 @@ from .base import OfaBasePreprocessor | |||
| class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaImageCaptioningPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| # Initialize transform | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: image.convert('RGB'), | |||
| @@ -6,21 +6,27 @@ from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaImageClassificationPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaImageClassificationPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| # Initialize transform | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: image.convert('RGB'), | |||
| @@ -1,21 +1,27 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaSummarizationPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaSummarizationPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| source = super().pre_caption( | |||
| @@ -1,21 +1,27 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaTextClassificationPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaTextClassificationPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| text1 = ' '.join( | |||
| @@ -3,21 +3,27 @@ from typing import Any, Dict | |||
| import torch | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaTextToImageSynthesisPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| self.max_src_length = 64 | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| @@ -6,21 +6,27 @@ from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaVisualEntailmentPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| # Initialize transform | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: image.convert('RGB'), | |||
| @@ -6,21 +6,27 @@ from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaVisualGroundingPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| # Initialize transform | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: image.convert('RGB'), | |||
| @@ -6,21 +6,27 @@ from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.preprocessors.image import load_image | |||
| from modelscope.utils.constant import ModeKeys | |||
| from .base import OfaBasePreprocessor | |||
| class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): | |||
| def __init__(self, cfg, model_dir, split, *args, **kwargs): | |||
| def __init__(self, | |||
| cfg, | |||
| model_dir, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| cfg(modelscope.utils.config.ConfigDict) : model config | |||
| model_dir (str): model path, | |||
| split: data phase | |||
| mode: preprocessor mode (model mode) | |||
| """ | |||
| super(OfaVisualQuestionAnsweringPreprocessor, | |||
| self).__init__(cfg, model_dir, split, *args, **kwargs) | |||
| self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
| # Initialize transform | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| lambda image: image.convert('RGB'), | |||
| @@ -0,0 +1 @@ | |||
| from .ofa_trainer import OFATrainer | |||
| @@ -78,6 +78,8 @@ class OFAFileDataset: | |||
| self.lineid_to_offset.append(offset) | |||
| self.total_row_count += 1 | |||
| offset += len(line.encode('utf-8')) | |||
| pickle.dump(self.lineid_to_offset, | |||
| open('{}.index'.format(self.file_path), 'wb')) | |||
| self._compute_start_pos_and_row_count() | |||
| print( | |||
| 'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping' | |||
| @@ -0,0 +1,120 @@ | |||
| import os | |||
| from os import path as osp | |||
| from typing import Dict, Optional | |||
| import torch | |||
| import torch.distributed as dist | |||
| import transformers | |||
| from torch.utils.data import DataLoader | |||
| from torch.utils.data.distributed import DistributedSampler | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model | |||
| from modelscope.preprocessors.multi_modal import OfaPreprocessor | |||
| from modelscope.preprocessors.ofa.utils.collate import collate_fn | |||
| from modelscope.trainers.base import BaseTrainer | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.utils.constant import ModeKeys, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.torch_utils import init_dist | |||
| from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | |||
| OFADataset, get_schedule) | |||
| logger = get_logger() | |||
| @TRAINERS.register_module(module_name=Trainers.ofa_tasks) | |||
| class OFATrainer(BaseTrainer): | |||
| def __init__(self, model: str, *args, **kwargs): | |||
| model = Model.from_pretrained(model) | |||
| super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION)) | |||
| self.model_dir = model.model_dir | |||
| self.model = model.model | |||
| self.device_id = 0 | |||
| self.total_epoch = self.cfg.train.epoch | |||
| self.train_batch_size = self.cfg.train.batch_size | |||
| self.val_batch_size = self.cfg.evaluation.batch_size | |||
| self.save_dir = self.cfg.train.save_dir | |||
| init_dist(launcher='pytorch') | |||
| self.train_dataset = OFADataset( | |||
| file_path=self.cfg.dataset.train_set, | |||
| selected_id_keys=self.cfg.dataset.selected_id_keys, | |||
| preprocessor=OfaPreprocessor( | |||
| model_dir=self.model_dir, mode=ModeKeys.TRAIN), | |||
| ) | |||
| self.val_dataset = OFADataset( | |||
| file_path=self.cfg.dataset.valid_set, | |||
| selected_id_keys=self.cfg.dataset.selected_id_keys, | |||
| preprocessor=OfaPreprocessor( | |||
| model_dir=self.model_dir, mode=ModeKeys.EVAL), | |||
| ) | |||
| epoch_steps = len( | |||
| self.train_dataset) // self.cfg.train.gradient_accumulation_steps | |||
| self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch | |||
| self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( | |||
| self.cfg.train.criterion) | |||
| def train(self, *args, **kwargs): | |||
| assert dist.is_initialized() | |||
| self.model.train() | |||
| self.model.to(self.device_id) | |||
| ddp_model = torch.nn.parallel.DistributedDataParallel( | |||
| self.model, device_ids=[ | |||
| self.device_id, | |||
| ]) | |||
| optimizer = transformers.AdamW( | |||
| self.model.parameters(), | |||
| lr=self.cfg.train.lr, | |||
| weight_decay=self.cfg.train.weight_decay, | |||
| correct_bias=False, | |||
| ) | |||
| scheduler_class, scheduler_args = get_schedule(self.cfg.train) | |||
| if scheduler_class is not None: | |||
| lr_scheduler = scheduler_class( | |||
| optimizer=optimizer, **scheduler_args) | |||
| else: | |||
| lr_scheduler = None | |||
| for epoch in range(self.total_epoch): | |||
| train_sampler = DistributedSampler( | |||
| dataset=self.train_dataset, shuffle=True) | |||
| train_sampler.set_epoch(epoch) | |||
| train_params = { | |||
| 'pin_memory': True, | |||
| 'collate_fn': collate_fn, | |||
| 'batch_size': self.train_batch_size, | |||
| 'shuffle': False, | |||
| 'drop_last': True, | |||
| 'sampler': train_sampler, | |||
| 'num_workers': 2, | |||
| } | |||
| train_loader = DataLoader(self.train_dataset, **train_params) | |||
| for idx, batch in enumerate(train_loader, start=1): | |||
| model_outputs = ddp_model(**batch) | |||
| loss, sample_size, logging_output = self.criterion( | |||
| model_outputs, batch) | |||
| loss.backward() | |||
| # apply the optimizer update before zeroing grads; the LR scheduler steps | |||
| # after the optimizer update | |||
| optimizer.step() | |||
| if lr_scheduler is not None: | |||
| lr_scheduler.step() | |||
| optimizer.zero_grad() | |||
| if idx % 10 == 0: | |||
| logger.info( | |||
| 'epoch: {}, train batch {}/{}, loss={:.5f}'.format( | |||
| epoch, idx, len(train_loader), loss.item())) | |||
| if dist.get_rank() == 0: | |||
| os.makedirs(self.save_dir, exist_ok=True) | |||
| torch.save(ddp_model.module.state_dict(), | |||
| f'{self.save_dir}/epoch{epoch}.bin') | |||
| def evaluate(self, | |||
| checkpoint_path: Optional[str] = None, | |||
| *args, | |||
| **kwargs) -> Dict[str, float]: | |||
| pass | |||
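| # Illustrative sketch (assumption, not part of the change): because | |||
| # init_dist(launcher='pytorch') is used, the trainer is expected to be started | |||
| # through a torch.distributed launcher, roughly: | |||
| # | |||
| # python -m torch.distributed.launch --nproc_per_node=1 train_script.py | |||
| # | |||
| # where the hypothetical train_script.py builds OFATrainer('<model-id-or-dir>') | |||
| # and calls .train(). | |||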
| @@ -2,36 +2,36 @@ | |||
| # All rights reserved. | |||
| # This source code is licensed under the Apache 2.0 license | |||
| # found in the LICENSE file in the root directory. | |||
| from os import path as osp | |||
| import math | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import transformers | |||
| from torch.nn.modules.loss import _Loss | |||
| from torch.utils.data import Dataset | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.preprocessors.multi_modal import OfaPreprocessor | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks | |||
| from .ofa_file_dataset import OFAFileDataset | |||
| class OFADataset(Dataset): | |||
| def __init__(self, | |||
| model_dir, | |||
| file_path, | |||
| file_path: str, | |||
| preprocessor: OfaPreprocessor, | |||
| selected_id_keys: str, | |||
| dtypes=None, | |||
| separator='\t', | |||
| cached_index=False, | |||
| split=ModeKeys.TRAIN, | |||
| **kwargs): | |||
| self.cfg = Config.from_file( | |||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
| selected_col_ids = self.cfg.dataset.selected_col_ids | |||
| selected_col_keys = self.cfg.dataset.selected_col_keys | |||
| assert selected_col_ids is not None | |||
| assert selected_col_keys is not None | |||
| self.selected_col_key_l = selected_col_keys.split(',') | |||
| assert len(self.selected_col_key_l) == len(selected_col_ids.split(',')) | |||
| assert selected_id_keys is not None | |||
| selected_col_ids = list() | |||
| selected_col_keys = list() | |||
| for id_key in selected_id_keys.split(','): | |||
| id, key = id_key.split(':') | |||
| selected_col_ids.append(id) | |||
| selected_col_keys.append(key) | |||
| self.dataset = OFAFileDataset( | |||
| file_path=file_path, | |||
| @@ -39,14 +39,278 @@ class OFADataset(Dataset): | |||
| dtypes=dtypes, | |||
| separator=separator, | |||
| cached_index=cached_index) | |||
| self.preprocessor = OfaPreprocessor(model_dir, split) | |||
| self.preprocessor = preprocessor | |||
| def __len__(self): | |||
| return len(self.dataset) | |||
| def __getitem__(self, index): | |||
| value_l = self.dataset[index] | |||
| values = self.dataset[index] | |||
| data = dict() | |||
| for key, value in zip(self.selected_col_key_l, value_l): | |||
| for key, value in zip(self.selected_col_keys, values): | |||
| data[key] = value | |||
| return self.preprocessor(data) | |||
| def construct_rdrop_sample(x): | |||
| if isinstance(x, dict): | |||
| for key in x: | |||
| x[key] = construct_rdrop_sample(x[key]) | |||
| return x | |||
| elif isinstance(x, torch.Tensor): | |||
| return x.repeat(2, *([1] * (x.dim() - 1))) | |||
| elif isinstance(x, int): | |||
| return x * 2 | |||
| elif isinstance(x, np.ndarray): | |||
| return x.repeat(2) | |||
| else: | |||
| raise NotImplementedError | |||
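| # Reviewer note: kl_loss below assumes p and q are log-probabilities (e.g. from | |||
| # F.log_softmax); torch.exp recovers probabilities for the target side of each | |||
| # F.kl_div call, and the two directions are averaged, as used for R-Drop. | |||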
| def kl_loss(p, q): | |||
| p_loss = F.kl_div(p, torch.exp(q), reduction='sum') | |||
| q_loss = F.kl_div(q, torch.exp(p), reduction='sum') | |||
| loss = (p_loss + q_loss) / 2 | |||
| return loss | |||
| def label_smoothed_nll_loss(lprobs, | |||
| target, | |||
| epsilon, | |||
| update_num, | |||
| reduce=True, | |||
| drop_worst_ratio=0.0, | |||
| drop_worst_after=0, | |||
| use_rdrop=False, | |||
| reg_alpha=1.0, | |||
| constraint_masks=None, | |||
| constraint_start=None, | |||
| constraint_end=None): | |||
| if target.dim() == lprobs.dim() - 1: | |||
| target = target.unsqueeze(-1) | |||
| nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1) | |||
| if constraint_masks is not None: | |||
| smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum( | |||
| dim=-1, keepdim=True).squeeze(-1) | |||
| eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6) | |||
| elif constraint_start is not None and constraint_end is not None: | |||
| constraint_range = [0, 1, 2, 3] + list( | |||
| range(constraint_start, constraint_end)) | |||
| smooth_loss = -lprobs[:, constraint_range].sum( | |||
| dim=-1, keepdim=True).squeeze(-1) | |||
| eps_i = epsilon / (len(constraint_range) - 1 + 1e-6) | |||
| else: | |||
| smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1) | |||
| eps_i = epsilon / (lprobs.size(-1) - 1) | |||
| loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss | |||
| if drop_worst_ratio > 0 and update_num > drop_worst_after: | |||
| if use_rdrop: | |||
| true_batch_size = loss.size(0) // 2 | |||
| _, indices = torch.topk( | |||
| loss[:true_batch_size], | |||
| k=int(true_batch_size * (1 - drop_worst_ratio)), | |||
| largest=False) | |||
| loss = torch.cat([loss[indices], loss[indices + true_batch_size]]) | |||
| nll_loss = torch.cat( | |||
| [nll_loss[indices], nll_loss[indices + true_batch_size]]) | |||
| lprobs = torch.cat( | |||
| [lprobs[indices], lprobs[indices + true_batch_size]]) | |||
| else: | |||
| loss, indices = torch.topk( | |||
| loss, | |||
| k=int(loss.shape[0] * (1 - drop_worst_ratio)), | |||
| largest=False) | |||
| nll_loss = nll_loss[indices] | |||
| lprobs = lprobs[indices] | |||
| ntokens = loss.numel() | |||
| nll_loss = nll_loss.sum() | |||
| loss = loss.sum() | |||
| if use_rdrop: | |||
| true_batch_size = lprobs.size(0) // 2 | |||
| p = lprobs[:true_batch_size] | |||
| q = lprobs[true_batch_size:] | |||
| if constraint_start is not None and constraint_end is not None: | |||
| constraint_range = [0, 1, 2, 3] + list( | |||
| range(constraint_start, constraint_end)) | |||
| p = p[:, constraint_range] | |||
| q = q[:, constraint_range] | |||
| loss += kl_loss(p, q) * reg_alpha | |||
| return loss, nll_loss, ntokens | |||
| class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| def __init__(self, args): | |||
| super().__init__() | |||
| self.sentence_avg = args.sentence_avg | |||
| self.eps = args.label_smoothing | |||
| self.ignore_prefix_size = args.ignore_prefix_size | |||
| self.ignore_eos = args.ignore_eos | |||
| self.report_accuracy = args.report_accuracy | |||
| self.drop_worst_ratio = args.drop_worst_ratio | |||
| self.drop_worst_after = args.drop_worst_after | |||
| self.use_rdrop = args.use_rdrop | |||
| self.reg_alpha = args.reg_alpha | |||
| self.sample_patch_num = args.sample_patch_num | |||
| self.constraint_start = None | |||
| self.constraint_end = None | |||
| if args.constraint_range is not None: | |||
| constraint_start, constraint_end = args.constraint_range.split(',') | |||
| self.constraint_start = int(constraint_start) | |||
| self.constraint_end = int(constraint_end) | |||
| self.padding_idx = args.tokenizer.pad_token_id | |||
| self.args = args | |||
| def forward(self, output, sample, update_num=0, reduce=True): | |||
| """Compute the loss for the given sample. | |||
| Returns a tuple with three elements: | |||
| 1) the loss | |||
| 2) the sample size, which is used as the denominator for the gradient | |||
| 3) logging outputs to display while training | |||
| """ | |||
| if isinstance(sample, list): | |||
| if self.sample_patch_num > 0: | |||
| sample[0]['net_input'][ | |||
| 'sample_patch_num'] = self.sample_patch_num | |||
| loss_v1, sample_size_v1, logging_output_v1 = self.forward( | |||
| output[0], sample[0], update_num, reduce) | |||
| loss_v2, sample_size_v2, logging_output_v2 = self.forward( | |||
| output[1], sample[1], update_num, reduce) | |||
| loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2 | |||
| sample_size = 1 | |||
| logging_output = { | |||
| 'loss': | |||
| loss.data, | |||
| 'loss_v1': | |||
| loss_v1.data, | |||
| 'loss_v2': | |||
| loss_v2.data, | |||
| 'nll_loss': | |||
| logging_output_v1['nll_loss'].data / sample_size_v1 | |||
| + logging_output_v2['nll_loss'].data / sample_size_v2, | |||
| 'ntokens': | |||
| logging_output_v1['ntokens'] + logging_output_v2['ntokens'], | |||
| 'nsentences': | |||
| logging_output_v1['nsentences'] | |||
| + logging_output_v2['nsentences'], | |||
| 'sample_size': | |||
| 1, | |||
| 'sample_size_v1': | |||
| sample_size_v1, | |||
| 'sample_size_v2': | |||
| sample_size_v2, | |||
| } | |||
| return loss, sample_size, logging_output | |||
| if self.use_rdrop: | |||
| construct_rdrop_sample(sample) | |||
| net_output = output | |||
| # model(**sample["net_input"]) | |||
| loss, nll_loss, ntokens = self.compute_loss( | |||
| net_output, sample, update_num, reduce=reduce) | |||
| sample_size = ( | |||
| sample['target'].size(0) if self.sentence_avg else ntokens) | |||
| logging_output = { | |||
| 'loss': loss.data, | |||
| 'nll_loss': nll_loss.data, | |||
| 'ntokens': sample['ntokens'], | |||
| 'nsentences': sample['nsentences'], | |||
| 'sample_size': sample_size, | |||
| } | |||
| return loss, sample_size, logging_output | |||
| def get_lprobs_and_target(self, net_output, sample): | |||
| conf = sample['conf'][:, None, None] if 'conf' in sample and sample[ | |||
| 'conf'] is not None else 1 | |||
| constraint_masks = None | |||
| if 'constraint_masks' in sample and sample[ | |||
| 'constraint_masks'] is not None: | |||
| constraint_masks = sample['constraint_masks'] | |||
| net_output[0].masked_fill_(~constraint_masks, -math.inf) | |||
| if self.constraint_start is not None and self.constraint_end is not None: | |||
| net_output[0][:, :, 4:self.constraint_start] = -math.inf | |||
| net_output[0][:, :, self.constraint_end:] = -math.inf | |||
| lprobs = F.log_softmax( | |||
| net_output[0], dim=-1, dtype=torch.float32) * conf | |||
| target = sample['target'] | |||
| if self.ignore_prefix_size > 0: | |||
| lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() | |||
| target = target[:, self.ignore_prefix_size:].contiguous() | |||
| if constraint_masks is not None: | |||
| constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable | |||
| if self.ignore_eos: | |||
| bsz, seq_len, embed_dim = lprobs.size() | |||
| eos_indices = target.eq(self.args.tokenizer.eos_token_id) | |||
| lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim) | |||
| target = target[~eos_indices].reshape(bsz, seq_len - 1) | |||
| if constraint_masks is not None: | |||
| constraint_masks = constraint_masks[~eos_indices].reshape( | |||
| bsz, seq_len - 1, embed_dim) | |||
| if constraint_masks is not None: | |||
| constraint_masks = constraint_masks.view(-1, | |||
| constraint_masks.size(-1)) | |||
| return lprobs.view(-1, | |||
| lprobs.size(-1)), target.view(-1), constraint_masks | |||
| def compute_loss(self, net_output, sample, update_num, reduce=True): | |||
| lprobs, target, constraint_masks = self.get_lprobs_and_target( | |||
| net_output, sample) | |||
| if constraint_masks is not None: | |||
| constraint_masks = constraint_masks[target != self.padding_idx] | |||
| lprobs = lprobs[target != self.padding_idx] | |||
| target = target[target != self.padding_idx] | |||
| loss, nll_loss, ntokens = label_smoothed_nll_loss( | |||
| lprobs, | |||
| target, | |||
| self.eps, | |||
| update_num, | |||
| reduce=reduce, | |||
| drop_worst_ratio=self.drop_worst_ratio, | |||
| drop_worst_after=self.drop_worst_after, | |||
| use_rdrop=self.use_rdrop, | |||
| reg_alpha=self.reg_alpha, | |||
| constraint_masks=constraint_masks, | |||
| constraint_start=self.constraint_start, | |||
| constraint_end=self.constraint_end) | |||
| return loss, nll_loss, ntokens | |||
| def get_schedule(args): | |||
| if args.schedule == 'const': | |||
| scheduler_class = transformers.get_constant_schedule_with_warmup | |||
| scheduler_args = { | |||
| 'num_warmup_steps': | |||
| int(args.warmup_proportion * args.num_train_steps) | |||
| } | |||
| elif args.schedule == 'linear': | |||
| scheduler_class = transformers.get_linear_schedule_with_warmup | |||
| scheduler_args = { | |||
| 'num_warmup_steps': | |||
| int(args.warmup_proportion * args.num_train_steps), | |||
| 'num_training_steps': args.num_train_steps | |||
| } | |||
| elif args.schedule == 'cosine': | |||
| scheduler_class = transformers.get_cosine_schedule_with_warmup | |||
| scheduler_args = { | |||
| 'num_warmup_steps': | |||
| int(args.warmup_proportion * args.num_train_steps), | |||
| 'num_training_steps': args.num_train_steps | |||
| } | |||
| elif args.schedule == 'polynomial_decay': | |||
| scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup | |||
| scheduler_args = { | |||
| 'num_warmup_steps': | |||
| int(args.warmup_proportion * args.num_train_steps), | |||
| 'num_training_steps': args.num_train_steps, | |||
| 'lr_end': args.lr_end | |||
| } | |||
| else: | |||
| raise NotImplementedError | |||
| return scheduler_class, scheduler_args | |||
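| # Illustrative sketch (reviewer addition): the (class, kwargs) pair returned by | |||
| # get_schedule is combined with an optimizer, mirroring OFATrainer.train: | |||
| # | |||
| # scheduler_class, scheduler_args = get_schedule(cfg.train) | |||
| # lr_scheduler = scheduler_class(optimizer=optimizer, **scheduler_args) | |||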
| @@ -0,0 +1,14 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from .fp16 import FP16_Module, FP16_Optimizer | |||
| @@ -0,0 +1,655 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Stable version of apex FP16 Optimizer""" | |||
| import torch | |||
| from torch import nn | |||
| from torch.autograd import Variable | |||
| from torch.nn.parameter import Parameter | |||
| from .fp16util import (master_params_to_model_params, | |||
| model_grads_to_master_grads) | |||
| from .loss_scaler import DynamicLossScaler, LossScaler | |||
| FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) | |||
| HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) | |||
| def conversion_helper(val, conversion): | |||
| """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" | |||
| if not isinstance(val, (tuple, list)): | |||
| return conversion(val) | |||
| rtn = [conversion_helper(v, conversion) for v in val] | |||
| if isinstance(val, tuple): | |||
| rtn = tuple(rtn) | |||
| return rtn | |||
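| # Reviewer example: conversion_helper((x, [y, z]), lambda t: t.half()) converts | |||
| # every element to half precision while preserving the nested tuple/list shape. | |||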
| def fp32_to_fp16(val): | |||
| """Convert fp32 `val` to fp16""" | |||
| def half_conversion(val): | |||
| val_typecheck = val | |||
| if isinstance(val_typecheck, (Parameter, Variable)): | |||
| val_typecheck = val.data | |||
| if isinstance(val_typecheck, FLOAT_TYPES): | |||
| val = val.half() | |||
| return val | |||
| return conversion_helper(val, half_conversion) | |||
| def fp16_to_fp32(val): | |||
| """Convert fp16 `val` to fp32""" | |||
| def float_conversion(val): | |||
| val_typecheck = val | |||
| if isinstance(val_typecheck, (Parameter, Variable)): | |||
| val_typecheck = val.data | |||
| if isinstance(val_typecheck, HALF_TYPES): | |||
| val = val.float() | |||
| return val | |||
| return conversion_helper(val, float_conversion) | |||
| class FP16_Module(nn.Module): | |||
| def __init__(self, module): | |||
| super(FP16_Module, self).__init__() | |||
| self.add_module('module', module.half()) | |||
| def forward(self, *inputs, **kwargs): | |||
| return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| return self.module.state_dict(destination, prefix, keep_vars) | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| self.module.load_state_dict(state_dict, strict=strict) | |||
| class FP16_Optimizer(object): | |||
| """ | |||
| :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, | |||
| and manage static or dynamic loss scaling and master weights in a manner transparent to the user. | |||
| For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, | |||
| and changing the call to ``backward``. | |||
| Example:: | |||
| model = torch.nn.Linear(D_in, D_out).cuda().half() | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||
| # Name the FP16_Optimizer instance to replace the existing optimizer | |||
| # (recommended but not required): | |||
| optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||
| ... | |||
| # loss.backward() becomes: | |||
| optimizer.backward(loss) | |||
| ... | |||
| Example with dynamic loss scaling:: | |||
| ... | |||
| optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) | |||
| # optional arg to control dynamic loss scaling behavior | |||
| # dynamic_loss_args={'scale_window' : 500}) | |||
| # Usually, dynamic_loss_args is not necessary. | |||
| Args: | |||
| init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa | |||
| static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa | |||
| dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa | |||
| dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa | |||
| verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa | |||
| ``init_optimizer`` is expected to have been constructed in the ordinary way. | |||
| It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be | |||
| named to replace ``init_optimizer``, for two reasons: | |||
| First, it means that references to the same name | |||
| later in the file will not have to change. | |||
| Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to | |||
| modify ``init_optimizer``. If you do choose a unique name for the new | |||
| :class:`FP16_Optimizer` instance, you should only work with this new instance, | |||
| because the preexisting optimizer might no longer behave as expected. | |||
| ``init_optimizer`` may be any Pytorch optimizer. | |||
| It may contain a mixture of fp16 and fp32 parameters organized into any number of | |||
| ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will | |||
| ingest these ``param_groups`` and remember them. | |||
| Calls to :: | |||
| loss.backward() | |||
| must be replaced with :: | |||
| optimizer.backward(loss) | |||
| because :class:`FP16_Optimizer` requires ownership of the backward pass to implement | |||
| loss scaling and copies to master gradients. | |||
| .. note:: | |||
| Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients | |||
| are downscaled before being applied. This means that adjusting the loss scale, or using | |||
| dynamic loss scaling, should not require retuning the learning rate or any other | |||
| hyperparameters. | |||
| **Advanced options** | |||
| **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. | |||
| See docstring for :attr:`step`. | |||
| **Gradient clipping**: Use :attr:`clip_master_grads`. | |||
| **Multiple losses**: If your model accumulates gradients from multiple losses, | |||
| this can be made more efficient by supplying ``update_master_grads=False`` | |||
| to :attr:`backward`. See docstring for :attr:`backward`. | |||
| **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: | |||
| print(optimizer.loss_scale) | |||
| optimizer.loss_scale = new_loss_scale | |||
| For static loss scaling, manually adjusting the loss scale over time is a reasonable | |||
| thing to do. During later epochs, gradients may become smaller, and a | |||
| higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss | |||
| scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting | |||
| the loss scale is not recommended. | |||
| **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in | |||
| Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` | |||
| should still work as intended. | |||
| """ | |||
| def __init__(self, | |||
| init_optimizer, | |||
| static_loss_scale=1.0, | |||
| dynamic_loss_scale=False, | |||
| dynamic_loss_args=None, | |||
| verbose=False): | |||
| if not torch.cuda.is_available(): | |||
| raise SystemError('Cannot use fp16 without CUDA.') | |||
| self.verbose = verbose | |||
| self.optimizer = init_optimizer | |||
| # init_state_dict sets up an alternative way to cast per-param state tensors. | |||
| # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. | |||
| # init_state_dict = init_optimizer.state_dict() | |||
| self.fp16_groups = [] | |||
| self.fp32_from_fp16_groups = [] | |||
| self.fp32_from_fp32_groups = [] | |||
| for i, param_group in enumerate(self.optimizer.param_groups): | |||
| self.maybe_print( | |||
| 'FP16_Optimizer processing param group {}:'.format(i)) | |||
| fp16_params_this_group = [] | |||
| fp32_params_this_group = [] | |||
| fp32_from_fp16_params_this_group = [] | |||
| for i, param in enumerate(param_group['params']): | |||
| if param.requires_grad: | |||
| if param.type() == 'torch.cuda.HalfTensor': | |||
| self.maybe_print( | |||
| 'FP16_Optimizer received torch.cuda.HalfTensor with {}' | |||
| .format(param.size())) | |||
| fp16_params_this_group.append(param) | |||
| master_param = param.detach().clone().float() | |||
| master_param.requires_grad = True | |||
| # Copy the model parallel flag. | |||
| master_param.model_parallel = param.model_parallel | |||
| param_group['params'][i] = master_param | |||
| fp32_from_fp16_params_this_group.append(master_param) | |||
| # Reset existing state dict key to the new master param. | |||
| # We still need to recast per-param state tensors, if any, to FP32. | |||
| if param in self.optimizer.state: | |||
| self.optimizer.state[ | |||
| master_param] = self.optimizer.state.pop(param) | |||
| elif param.type() == 'torch.cuda.FloatTensor': | |||
| self.maybe_print( | |||
| 'FP16_Optimizer received torch.cuda.FloatTensor with {}' | |||
| .format(param.size())) | |||
| fp32_params_this_group.append(param) | |||
| param_group['params'][i] = param | |||
| else: | |||
| raise TypeError( | |||
| 'Wrapped parameters must be either ' | |||
| 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' | |||
| 'Received {}'.format(param.type())) | |||
| self.fp16_groups.append(fp16_params_this_group) | |||
| self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) | |||
| self.fp32_from_fp32_groups.append(fp32_params_this_group) | |||
| # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors | |||
| self.optimizer.load_state_dict(self.optimizer.state_dict()) | |||
| # alternative way to cast per-param state tensors: | |||
| # self.optimizer.load_state_dict(init_state_dict) | |||
| if dynamic_loss_scale: | |||
| self.dynamic_loss_scale = True | |||
| if dynamic_loss_args is not None: | |||
| self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) | |||
| else: | |||
| self.loss_scaler = DynamicLossScaler() | |||
| else: | |||
| self.dynamic_loss_scale = False | |||
| self.loss_scaler = LossScaler(static_loss_scale) | |||
| self.overflow = False | |||
| self.first_closure_call_this_step = True | |||
| self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_ | |||
| def maybe_print(self, msg): | |||
| if self.verbose: | |||
| print(msg) | |||
| def __getstate__(self): | |||
| raise RuntimeError( | |||
| 'FP16_Optimizer should be serialized using state_dict().') | |||
| def __setstate__(self, state): | |||
| raise RuntimeError( | |||
| 'FP16_Optimizer should be deserialized using load_state_dict().') | |||
| def zero_grad(self, set_grads_to_None=False): | |||
| """ | |||
| Zero fp32 and fp16 parameter grads. | |||
| """ | |||
| # In principle, only the .grad attributes of the model params need to be zeroed, | |||
| # because gradients are copied into the FP32 master params. However, we zero | |||
| # all gradients owned by the optimizer, just to be safe: | |||
| for group in self.optimizer.param_groups: | |||
| for p in group['params']: | |||
| if set_grads_to_None: | |||
| p.grad = None | |||
| else: | |||
| if p.grad is not None: | |||
| p.grad.detach_() | |||
| p.grad.zero_() | |||
| # Zero fp16 gradients owned by the model: | |||
| for fp16_group in self.fp16_groups: | |||
| for param in fp16_group: | |||
| if set_grads_to_None: | |||
| param.grad = None | |||
| else: | |||
| if param.grad is not None: | |||
| param.grad.detach_( | |||
| ) # as in torch.optim.optimizer.zero_grad() | |||
| param.grad.zero_() | |||
| def _check_overflow(self): | |||
| params = [] | |||
| for group in self.fp16_groups: | |||
| for param in group: | |||
| params.append(param) | |||
| for group in self.fp32_from_fp32_groups: | |||
| for param in group: | |||
| params.append(param) | |||
| self.overflow = self.loss_scaler.has_overflow(params) | |||
| def _update_scale(self, has_overflow=False): | |||
| self.loss_scaler.update_scale(has_overflow) | |||
| def _master_params_to_model_params(self): | |||
| for fp16_group, fp32_from_fp16_group in zip( | |||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||
| master_params_to_model_params(fp16_group, fp32_from_fp16_group) | |||
| def _model_params_to_master_params(self): | |||
| for fp16_group, fp32_from_fp16_group in zip( | |||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||
| master_params_to_model_params(fp32_from_fp16_group, fp16_group) | |||
| # To consider: Integrate distributed with this wrapper by registering a hook on each variable | |||
| # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. | |||
| def _model_grads_to_master_grads(self): | |||
| for fp16_group, fp32_from_fp16_group in zip( | |||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||
| model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) | |||
| def _downscale_master(self): | |||
| if self.loss_scale != 1.0: | |||
| for group in self.optimizer.param_groups: | |||
| for param in group['params']: | |||
| if param.grad is not None: | |||
| param.grad.data.mul_(1. / self.loss_scale) | |||
| def clip_master_grads(self, max_norm, norm_type=2): | |||
| """ | |||
| Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. | |||
| Args: | |||
| max_norm (float or int): max norm of the gradients | |||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||
| infinity norm. | |||
| Returns: | |||
| Total norm of the current fp32 gradients (viewed as a single vector). | |||
| .. warning:: | |||
| Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa | |||
| """ | |||
| if not self.overflow: | |||
| fp32_params = [] | |||
| for param_group in self.optimizer.param_groups: | |||
| for param in param_group['params']: | |||
| fp32_params.append(param) | |||
| return self.clip_grad_norm(fp32_params, max_norm, norm_type) | |||
| else: | |||
| return -1 | |||
| def state_dict(self): | |||
| """ | |||
| Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. | |||
| This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict | |||
| of the contained Pytorch optimizer. | |||
| Example:: | |||
| checkpoint = {} | |||
| checkpoint['model'] = model.state_dict() | |||
| checkpoint['optimizer'] = optimizer.state_dict() | |||
| torch.save(checkpoint, "saved.pth") | |||
| """ | |||
| state_dict = {} | |||
| state_dict['loss_scaler'] = self.loss_scaler | |||
| state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale | |||
| state_dict['overflow'] = self.overflow | |||
| state_dict[ | |||
| 'first_closure_call_this_step'] = self.first_closure_call_this_step | |||
| state_dict['optimizer_state_dict'] = self.optimizer.state_dict() | |||
| state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups | |||
| return state_dict | |||
| def load_state_dict(self, state_dict): | |||
| """ | |||
| Loads a state_dict created by an earlier call to state_dict(). | |||
| If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, | |||
| whose parameters in turn came from ``model``, it is expected that the user | |||
| will call ``model.load_state_dict()`` before | |||
| ``fp16_optimizer_instance.load_state_dict()`` is called. | |||
| Example:: | |||
| model = torch.nn.Linear(D_in, D_out).cuda().half() | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||
| optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||
| ... | |||
| checkpoint = torch.load("saved.pth") | |||
| model.load_state_dict(checkpoint['model']) | |||
| optimizer.load_state_dict(checkpoint['optimizer']) | |||
| """ | |||
| # I think it should actually be ok to reload the optimizer before the model. | |||
| self.loss_scaler = state_dict['loss_scaler'] | |||
| self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] | |||
| self.overflow = state_dict['overflow'] | |||
| self.first_closure_call_this_step = state_dict[ | |||
| 'first_closure_call_this_step'] | |||
| self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) | |||
| # At this point, the optimizer's references to the model's fp32 parameters are up to date. | |||
| # The optimizer's hyperparameters and internal buffers are also up to date. | |||
| # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still | |||
| # out of date. There are two options. | |||
| # 1: Refresh the master params from the model's fp16 params. | |||
| # This requires less storage but incurs precision loss. | |||
| # 2: Save and restore the fp32 master copies separately. | |||
| # We choose option 2. | |||
| # | |||
| # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device | |||
| # of their associated parameters, because it's possible those buffers might not exist yet in | |||
| # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been | |||
| # constructed in the same way as the one whose state_dict we are loading, the same master params | |||
| # are guaranteed to exist, so we can just copy_() from the saved master params. | |||
| for current_group, saved_group in zip(self.fp32_from_fp16_groups, | |||
| state_dict['fp32_from_fp16']): | |||
| for current, saved in zip(current_group, saved_group): | |||
| current.data.copy_(saved.data) | |||
| def step(self, closure=None): # could add clip option. | |||
| """ | |||
| If no closure is supplied, :attr:`step` should be called after | |||
| ``fp16_optimizer_obj.backward(loss)``. | |||
| :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to | |||
| :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params | |||
| originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run | |||
| another forward pass using their model. | |||
| If a closure is supplied, :attr:`step` may be called without a prior call to | |||
| :attr:`backward(loss)`. | |||
| This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. | |||
| However, the user should take care that any ``loss.backward()`` call within the closure | |||
| has been replaced by ``fp16_optimizer_obj.backward(loss)``. | |||
| Args: | |||
| closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa | |||
| Example with closure:: | |||
| # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an | |||
| # existing pytorch optimizer. | |||
| for input, target in dataset: | |||
| def closure(): | |||
| optimizer.zero_grad() | |||
| output = model(input) | |||
| loss = loss_fn(output, target) | |||
| # loss.backward() becomes: | |||
| optimizer.backward(loss) | |||
| return loss | |||
| optimizer.step(closure) | |||
| .. warning:: | |||
| Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. | |||
| .. _`ordinary Pytorch optimizer use`: | |||
| http://pytorch.org/docs/master/optim.html#optimizer-step-closure | |||
| """ | |||
| scale = self.loss_scaler.loss_scale | |||
| self._update_scale(self.overflow) | |||
| if self.overflow: | |||
| self.maybe_print( | |||
| 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' | |||
| .format(scale, self.loss_scale)) | |||
| return | |||
| if closure is not None: | |||
| retval = self._step_with_closure(closure) | |||
| else: | |||
| retval = self.optimizer.step() | |||
| self._master_params_to_model_params() | |||
| return retval | |||
| def _step_with_closure(self, closure): | |||
| def wrapped_closure(): | |||
| # helpful for debugging | |||
| # print("Calling wrapped_closure, first_closure_call_this_step = {}" | |||
| # .format(self.first_closure_call_this_step)) | |||
| if self.first_closure_call_this_step: | |||
| # We expect that the fp16 params are initially fresh on entering self.step(), | |||
| # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() | |||
| # is called within self.optimizer.step(). | |||
| self.first_closure_call_this_step = False | |||
| else: | |||
| # If self.optimizer.step() internally calls wrapped_closure more than once, | |||
| # it may update the fp32 params after each call. However, self.optimizer | |||
| # doesn't know about the fp16 params at all. If the fp32 params get updated, | |||
| # we can't rely on self.optimizer to refresh the fp16 params. We need | |||
| # to handle that manually: | |||
| self._master_params_to_model_params() | |||
| # Our API expects the user to give us ownership of the backward() call by | |||
| # replacing all calls to loss.backward() with optimizer.backward(loss). | |||
| # This requirement holds whether or not the call to backward() is made within a closure. | |||
| # If the user is properly calling optimizer.backward(loss) within "closure," | |||
| # calling closure() here will give the fp32 master params fresh gradients | |||
| # for the optimizer to play with, so all wrapped_closure needs to do is call | |||
| # closure() and return the loss. | |||
| temp_loss = closure() | |||
| while (self.overflow): | |||
| scale = self.loss_scaler.loss_scale | |||
| self._update_scale(self.overflow) | |||
| self.maybe_print( | |||
| 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' | |||
| 'reducing to {}'.format(scale, self.loss_scale)) | |||
| temp_loss = closure() | |||
| return temp_loss | |||
| retval = self.optimizer.step(wrapped_closure) | |||
| self.first_closure_call_this_step = True | |||
| return retval | |||
| def backward(self, loss, update_master_grads=True, retain_graph=False): | |||
| """ | |||
| :attr:`backward` performs the following conceptual steps: | |||
| 1. fp32_loss = loss.float() (see first Note below) | |||
| 2. scaled_loss = fp32_loss*loss_scale | |||
| 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa | |||
| 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa | |||
| 5. Finally, master grads are divided by loss_scale. | |||
| In this way, after :attr:`backward`, the master params have fresh gradients, | |||
| and :attr:`step` may be called. | |||
| .. note:: | |||
| :attr:`backward` internally converts the loss to fp32 before applying the loss scale. | |||
| This provides some additional safety against overflow if the user has supplied an | |||
| fp16 loss value. | |||
| However, for maximum overflow safety, the user should | |||
| compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to | |||
| :attr:`backward`. | |||
| .. warning:: | |||
| The gradients found in a model's leaves after the call to | |||
| :attr:`backward` should not be regarded as valid in general, | |||
| because it's possible | |||
| they have been scaled (and in the case of dynamic loss scaling, | |||
| the scale factor may change over time). | |||
| If the user wants to inspect gradients after a call to :attr:`backward`, | |||
| only the master gradients should be regarded as valid. These can be retrieved via | |||
| :attr:`inspect_master_grad_data()`. | |||
| Args: | |||
| loss: The loss output by the user's model. loss may be either float or half (but see first Note above). | |||
| update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa | |||
| retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa | |||
| Example:: | |||
| # Ordinary operation: | |||
| optimizer.backward(loss) | |||
| # Naive operation with multiple losses (technically valid, but less efficient): | |||
| # fp32 grads will be correct after the second call, but | |||
| # the first call incurs an unnecessary fp16->fp32 grad copy. | |||
| optimizer.backward(loss1) | |||
| optimizer.backward(loss2) | |||
| # More efficient way to handle multiple losses: | |||
| # The fp16->fp32 grad copy is delayed until fp16 grads from all | |||
| # losses have been accumulated. | |||
| optimizer.backward(loss1, update_master_grads=False) | |||
| optimizer.backward(loss2, update_master_grads=False) | |||
| optimizer.update_master_grads() | |||
| """ | |||
| # To consider: try multiple backward passes using retain_graph=True to find | |||
| # a loss scale that works. After you find a loss scale that works, do a final dummy | |||
| # backward pass with retain_graph=False to tear down the graph. Doing this would avoid | |||
| # discarding the iteration, but probably wouldn't improve overall efficiency. | |||
| self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) | |||
| if update_master_grads: | |||
| self.update_master_grads() | |||
| def update_master_grads(self): | |||
| """ | |||
| Copy the ``.grad`` attribute from stored references to fp16 parameters to | |||
| the ``.grad`` attribute of the fp32 master parameters that are directly | |||
| updated by the optimizer. :attr:`update_master_grads` only needs to be called if | |||
| ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. | |||
| """ | |||
| if self.dynamic_loss_scale: | |||
| self._check_overflow() | |||
| if self.overflow: return # noqa | |||
| self._model_grads_to_master_grads() | |||
| self._downscale_master() | |||
| def inspect_master_grad_data(self): | |||
| """ | |||
| When running with :class:`FP16_Optimizer`, | |||
| ``.grad`` attributes of a model's fp16 leaves should not be | |||
| regarded as truthful, because they might be scaled. | |||
| After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, | |||
| the fp32 master params' ``.grad`` | |||
| attributes will contain valid gradients properly divided by the loss scale. However, | |||
| because :class:`FP16_Optimizer` flattens some parameters, accessing them may be | |||
| nonintuitive. :attr:`inspect_master_grad_data` | |||
| allows those gradients to be viewed with shapes corresponding to their associated model leaves. | |||
| Returns: | |||
| List of lists (one list for each parameter group). The list for each parameter group | |||
| is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. | |||
| """ | |||
| if self.overflow: | |||
| print( | |||
| 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' | |||
| 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' | |||
| ) | |||
| return None | |||
| else: | |||
| # The optimizer owns only references to master params. | |||
| master_grads_data = [] | |||
| for param_group in self.optimizer.param_groups: | |||
| master_grads_this_group = [] | |||
| for param in param_group['params']: | |||
| if param.grad is not None: | |||
| master_grads_this_group.append(param.grad.data) | |||
| else: | |||
| master_grads_this_group.append(None) | |||
| master_grads_data.append(master_grads_this_group) | |||
| return master_grads_data | |||
| # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" | |||
| def _get_loss_scale(self): | |||
| return self.loss_scaler.loss_scale | |||
| def _set_loss_scale(self, value): | |||
| self.loss_scaler.cur_scale = value | |||
| loss_scale = property(_get_loss_scale, _set_loss_scale) | |||
| # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" | |||
| def _get_state(self): | |||
| return self.optimizer.state | |||
| def _set_state(self, value): | |||
| self.optimizer.state = value | |||
| state = property(_get_state, _set_state) | |||
| # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" | |||
| # (for example, to adjust the learning rate) | |||
| def _get_param_groups(self): | |||
| return self.optimizer.param_groups | |||
| def _set_param_groups(self, value): | |||
| self.optimizer.param_groups = value | |||
| param_groups = property(_get_param_groups, _set_param_groups) | |||
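For orientation, a minimal training-loop sketch using the wrapper promoted above (``backward``, ``step``, and the ``loss_scale`` property). It assumes a CUDA device, the ``dynamic_loss_scale=True`` constructor argument described in the loss-scaler docstrings later in this change, and a ``zero_grad`` passthrough on the wrapper (not shown in this hunk); the model and data are illustrative only.

# Sketch only: FP16_Optimizer is the class defined above and is assumed to be in scope.
import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda().half()          # fp16 model (no batchnorm)
optimizer = FP16_Optimizer(
    torch.optim.SGD(model.parameters(), lr=1e-3),
    dynamic_loss_scale=True)

x = torch.randn(8, 16, device='cuda').half()
y = torch.randn(8, 4, device='cuda')
for _ in range(3):
    optimizer.zero_grad()                        # assumed passthrough to the wrapped optimizer
    # Compute the loss in fp32, as the backward() docstring recommends.
    loss = nn.functional.mse_loss(model(x).float(), y)
    optimizer.backward(loss)    # scaled backward + fp16 -> fp32 master-grad copy
    optimizer.step()            # the update is skipped internally on overflow
    print('loss scale:', optimizer.loss_scale)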
| @@ -0,0 +1,216 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
| from torch.autograd import Variable | |||
| class tofp16(nn.Module): | |||
| """ | |||
| Utility module that implements:: | |||
| def forward(self, input): | |||
| return input.half() | |||
| """ | |||
| def __init__(self): | |||
| super(tofp16, self).__init__() | |||
| def forward(self, input): | |||
| return input.half() | |||
| def BN_convert_float(module): | |||
| """ | |||
| Utility function for network_to_half(). | |||
| Retained for legacy purposes. | |||
| """ | |||
| if isinstance( | |||
| module, | |||
| torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: | |||
| module.float() | |||
| for child in module.children(): | |||
| BN_convert_float(child) | |||
| return module | |||
| def network_to_half(network): | |||
| """ | |||
| Convert model to half precision in a batchnorm-safe way. | |||
| Retained for legacy purposes. It is recommended to use FP16Model. | |||
| """ | |||
| return nn.Sequential(tofp16(), BN_convert_float(network.half())) | |||
| def convert_module(module, dtype): | |||
| """ | |||
| Converts a module's immediate parameters and buffers to dtype. | |||
| """ | |||
| for param in module.parameters(recurse=False): | |||
| if param is not None: | |||
| if param.data.dtype.is_floating_point: | |||
| param.data = param.data.to(dtype=dtype) | |||
| if param._grad is not None and param._grad.data.dtype.is_floating_point: | |||
| param._grad.data = param._grad.data.to(dtype=dtype) | |||
| for buf in module.buffers(recurse=False): | |||
| if buf is not None and buf.data.dtype.is_floating_point: | |||
| buf.data = buf.data.to(dtype=dtype) | |||
| def convert_network(network, dtype): | |||
| """ | |||
| Converts a network's parameters and buffers to dtype. | |||
| """ | |||
| for module in network.modules(): | |||
| if isinstance(module, torch.nn.modules.batchnorm._BatchNorm | |||
| ) and module.affine is True: | |||
| continue | |||
| convert_module(module, dtype) | |||
| return network | |||
| class FP16Model(nn.Module): | |||
| """ | |||
| Convert model to half precision in a batchnorm-safe way. | |||
| """ | |||
| def __init__(self, network): | |||
| super(FP16Model, self).__init__() | |||
| self.network = convert_network(network, dtype=torch.half) | |||
| def forward(self, *inputs): | |||
| inputs = tuple(t.half() for t in inputs) | |||
| return self.network(*inputs) | |||
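A small sketch of what the batchnorm-safe conversion produces, assuming the classes above are importable: convolution weights become fp16 while affine BatchNorm parameters stay fp32, and on a CUDA device (where the cuDNN batch-norm kernels accept fp16 activations with fp32 statistics) the forward pass returns fp16 outputs.

# Sketch only: FP16Model / network_to_half are the utilities defined above.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
half_net = FP16Model(net)                  # or: network_to_half(net)

print(half_net.network[0].weight.dtype)    # torch.float16
print(half_net.network[1].weight.dtype)    # torch.float32 (affine BN kept in fp32)

if torch.cuda.is_available():
    half_net = half_net.cuda()
    # The fp16-input / fp32-BN forward relies on the cuDNN batch-norm path.
    out = half_net(torch.randn(2, 3, 16, 16, device='cuda'))
    print(out.dtype)                       # torch.float16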
| def backwards_debug_hook(grad): | |||
| raise RuntimeError( | |||
| 'master_params received a gradient in the backward pass!') | |||
| def prep_param_lists(model, flat_master=False): | |||
| """ | |||
| Creates a list of FP32 master parameters for a given model, as in | |||
| `Training Neural Networks with Mixed Precision: Real Examples`_. | |||
| Args: | |||
| model (torch.nn.Module): Existing Pytorch model | |||
| flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa | |||
| Returns: | |||
| A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa | |||
| Example:: | |||
| model_params, master_params = prep_param_lists(model) | |||
| .. warning:: | |||
| Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa | |||
| .. _`Training Neural Networks with Mixed Precision: Real Examples`: | |||
| http://on-demand.gputechconf.com/gtc/2018/video/S81012/ | |||
| """ | |||
| model_params = [ | |||
| param for param in model.parameters() if param.requires_grad | |||
| ] | |||
| if flat_master: | |||
| # Give the user some more useful error messages | |||
| try: | |||
| # flatten_dense_tensors returns a contiguous flat array. | |||
| # http://pytorch.org/docs/master/_modules/torch/_utils.html | |||
| master_params = _flatten_dense_tensors( | |||
| [param.data for param in model_params]).float() | |||
| except: # noqa | |||
| print( | |||
| 'Error in prep_param_lists: model may contain a mixture of parameters ' | |||
| 'of different types. Use flat_master=False, or use FP16_Optimizer.' | |||
| ) | |||
| raise | |||
| master_params = torch.nn.Parameter(master_params) | |||
| master_params.requires_grad = True | |||
| # master_params.register_hook(backwards_debug_hook) | |||
| if master_params.grad is None: | |||
| master_params.grad = master_params.new(*master_params.size()) | |||
| return model_params, [master_params] | |||
| else: | |||
| master_params = [ | |||
| param.clone().float().detach() for param in model_params | |||
| ] | |||
| for param in master_params: | |||
| param.requires_grad = True | |||
| return model_params, master_params | |||
| def model_grads_to_master_grads(model_params, | |||
| master_params, | |||
| flat_master=False): | |||
| """ | |||
| Copy model gradients to master gradients. | |||
| Args: | |||
| model_params: List of model parameters created by :func:`prep_param_lists`. | |||
| master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa | |||
| """ | |||
| if flat_master: | |||
| # The flattening may incur one more deep copy than is necessary. | |||
| master_params[0].grad.data.copy_( | |||
| _flatten_dense_tensors([p.grad.data for p in model_params])) | |||
| else: | |||
| for model, master in zip(model_params, master_params): | |||
| if model.grad is not None: | |||
| if master.grad is None: | |||
| master.grad = Variable( | |||
| master.data.new(*master.data.size())) | |||
| master.grad.data.copy_(model.grad.data) | |||
| else: | |||
| master.grad = None | |||
| def master_params_to_model_params(model_params, | |||
| master_params, | |||
| flat_master=False): | |||
| """ | |||
| Copy master parameters to model parameters. | |||
| Args: | |||
| model_params: List of model parameters created by :func:`prep_param_lists`. | |||
| master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa | |||
| """ | |||
| if flat_master: | |||
| for model, master in zip( | |||
| model_params, | |||
| _unflatten_dense_tensors(master_params[0].data, model_params)): | |||
| model.data.copy_(master) | |||
| else: | |||
| for model, master in zip(model_params, master_params): | |||
| model.data.copy_(master.data) | |||
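Taken together, ``prep_param_lists``, ``model_grads_to_master_grads`` and ``master_params_to_model_params`` support a manual master-weight loop. A minimal sketch, assuming a CUDA device and a static loss scale chosen purely for illustration:

# Sketch only: the three helpers above are assumed to be in scope.
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(16, 4).cuda().half()
model_params, master_params = prep_param_lists(model)        # fp16 refs, fp32 copies
optimizer = torch.optim.SGD(master_params, lr=1e-3)           # optimizer only sees fp32
loss_scale = 128.0                                            # illustrative static scale

x = torch.randn(8, 16, device='cuda').half()
y = torch.randn(8, 4, device='cuda')
for _ in range(3):
    model.zero_grad()
    loss = F.mse_loss(model(x).float(), y)
    (loss * loss_scale).backward()                            # scaled fp16 backward
    model_grads_to_master_grads(model_params, master_params)  # copy grads to fp32
    for p in master_params:
        p.grad.data.mul_(1.0 / loss_scale)                    # unscale in fp32
    optimizer.step()
    master_params_to_model_params(model_params, master_params)  # fp32 -> fp16 copy-back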
| # Backward compatibility fixes | |||
| def to_python_float(t): | |||
| if hasattr(t, 'item'): | |||
| return t.item() | |||
| else: | |||
| return t[0] | |||
| TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||
| TORCH_MINOR = int(torch.__version__.split('.')[1]) | |||
| @@ -0,0 +1,237 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| # item() is a recent addition, so this helps with backward compatibility. | |||
| def to_python_float(t): | |||
| if hasattr(t, 'item'): | |||
| return t.item() | |||
| else: | |||
| return t[0] | |||
| class LossScaler: | |||
| """ | |||
| Class that manages a static loss scale. This class is intended to interact with | |||
| :class:`FP16_Optimizer`, and should not be directly manipulated by the user. | |||
| Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to | |||
| :class:`FP16_Optimizer`'s constructor. | |||
| Args: | |||
| scale (float, optional, default=1.0): The loss scale. | |||
| """ | |||
| def __init__(self, scale=1): | |||
| self.cur_scale = scale | |||
| # `params` is a list / generator of torch.Variable | |||
| def has_overflow(self, params): | |||
| return False | |||
| # `x` is a torch.Tensor | |||
| def _has_inf_or_nan(x): | |||
| return False | |||
| def update_scale(self, overflow): | |||
| pass | |||
| @property | |||
| def loss_scale(self): | |||
| return self.cur_scale | |||
| def scale_gradient(self, module, grad_in, grad_out): | |||
| return tuple(self.loss_scale * g for g in grad_in) | |||
| def backward(self, loss, retain_graph=False): | |||
| scaled_loss = loss * self.loss_scale | |||
| scaled_loss.backward(retain_graph=retain_graph) | |||
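A short sketch of the invariant behind ``LossScaler.backward``: leaf gradients come out multiplied by ``loss_scale`` and must be divided by it before they are used, which is what ``FP16_Optimizer`` does on the fp32 master gradients.

# Sketch only: LossScaler is the class defined above.
import torch

w = torch.randn(4, requires_grad=True)
scaler = LossScaler(scale=64.0)

loss = (w * 2.0).sum()
scaler.backward(loss)                 # backpropagates loss * 64
print(w.grad)                         # roughly 128 in every position (2 * 64)
print(w.grad / scaler.loss_scale)     # the true gradient, roughly 2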
| class DynamicLossScaler: | |||
| """ | |||
| Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` | |||
| indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of | |||
| :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` | |||
| operates, because the default options can be changed using the | |||
| ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. | |||
| Loss scaling is designed to combat the problem of underflowing gradients encountered at long | |||
| times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss | |||
| scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are | |||
| encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has | |||
| occurred. | |||
| :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, | |||
| and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. | |||
| If a certain number of iterations occur without overflowing gradients detected, | |||
| :class:`DynamicLossScaler` increases the loss scale once more. | |||
| In this way :class:`DynamicLossScaler` attempts to "ride the edge" of | |||
| always using the highest loss scale possible without incurring overflow. | |||
| Args: | |||
| init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`. | |||
| scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa | |||
| scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. # noqa | |||
| """ | |||
| def __init__(self, | |||
| init_scale=2**32, | |||
| scale_factor=2., | |||
| scale_window=1000, | |||
| min_scale=1, | |||
| delayed_shift=1, | |||
| consecutive_hysteresis=False): | |||
| self.cur_scale = init_scale | |||
| self.cur_iter = 0 | |||
| self.last_overflow_iter = -1 | |||
| self.scale_factor = scale_factor | |||
| self.scale_window = scale_window | |||
| self.min_scale = min_scale | |||
| self.delayed_shift = delayed_shift | |||
| self.cur_hysteresis = delayed_shift | |||
| self.consecutive_hysteresis = consecutive_hysteresis | |||
| # `params` is a list / generator of torch.Variable | |||
| def has_overflow_serial(self, params): | |||
| for p in params: | |||
| if p.grad is not None and DynamicLossScaler._has_inf_or_nan( | |||
| p.grad.data): | |||
| return True | |||
| return False | |||
| def has_overflow(self, params): | |||
| overflow = self.has_overflow_serial(params) | |||
| overflow_gpu = torch.cuda.ByteTensor([overflow]) | |||
| overflow = overflow_gpu[0].item() | |||
| return bool(overflow) | |||
| # `x` is a torch.Tensor | |||
| def _has_inf_or_nan(x): | |||
| try: | |||
| # if x is half, the .float() incurs an additional deep copy, but it's necessary if | |||
| # PyTorch's .sum() creates a one-element tensor of the same type as x | |||
| # (which is true for some recent versions of PyTorch). | |||
| cpu_sum = float(x.float().sum()) | |||
| # More efficient version that can be used if .sum() returns a Python scalar | |||
| # cpu_sum = float(x.sum()) | |||
| except RuntimeError as instance: | |||
| # We want to check if instance is actually an overflow exception. | |||
| # RuntimeError could come from a different error. | |||
| # If so, we still want the exception to propagate. | |||
| if 'value cannot be converted' not in instance.args[0]: | |||
| raise | |||
| return True | |||
| else: | |||
| if cpu_sum == float( | |||
| 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: | |||
| return True | |||
| return False | |||
| # `overflow` is boolean indicating whether the gradient overflowed | |||
| def update_scale(self, overflow): | |||
| if not hasattr(self, 'min_scale'): | |||
| self.min_scale = 1 | |||
| if not hasattr(self, 'delayed_shift'): | |||
| self.delayed_shift = 1 | |||
| if not hasattr(self, 'cur_hysteresis'): | |||
| self.cur_hysteresis = 1 | |||
| if not hasattr(self, 'consecutive_hysteresis'): | |||
| self.consecutive_hysteresis = True | |||
| if overflow: | |||
| # self.cur_scale /= self.scale_factor | |||
| if self.delayed_shift == 1 or self.cur_hysteresis == 1: | |||
| self.cur_scale = max(self.cur_scale / self.scale_factor, | |||
| self.min_scale) | |||
| else: | |||
| self.cur_hysteresis -= 1 | |||
| self.last_overflow_iter = self.cur_iter | |||
| else: | |||
| if self.consecutive_hysteresis: | |||
| self.cur_hysteresis = self.delayed_shift | |||
| if (self.cur_iter | |||
| - self.last_overflow_iter) % self.scale_window == 0: | |||
| if not self.consecutive_hysteresis: | |||
| self.cur_hysteresis = self.delayed_shift | |||
| self.cur_scale *= self.scale_factor | |||
| self.cur_iter += 1 | |||
| @property | |||
| def loss_scale(self): | |||
| return self.cur_scale | |||
| def scale_gradient(self, module, grad_in, grad_out): | |||
| return tuple(self.loss_scale * g for g in grad_in) | |||
| def backward(self, loss, retain_graph=False): | |||
| scaled_loss = loss * self.loss_scale | |||
| scaled_loss.backward(retain_graph=retain_graph) | |||
| ############################################################## | |||
| # Example usage below here -- assuming it's in a separate file | |||
| ############################################################## | |||
| """ | |||
| TODO: separate out into an example. | |||
| if __name__ == "__main__": | |||
| import torch | |||
| from torch.autograd import Variable | |||
| from dynamic_loss_scaler import DynamicLossScaler | |||
| # N is batch size; D_in is input dimension; | |||
| # H is hidden dimension; D_out is output dimension. | |||
| N, D_in, H, D_out = 64, 1000, 100, 10 | |||
| # Create random Tensors to hold inputs and outputs, and wrap them in Variables. | |||
| x = Variable(torch.randn(N, D_in), requires_grad=False) | |||
| y = Variable(torch.randn(N, D_out), requires_grad=False) | |||
| w1 = Variable(torch.randn(D_in, H), requires_grad=True) | |||
| w2 = Variable(torch.randn(H, D_out), requires_grad=True) | |||
| parameters = [w1, w2] | |||
| learning_rate = 1e-6 | |||
| optimizer = torch.optim.SGD(parameters, lr=learning_rate) | |||
| loss_scaler = DynamicLossScaler() | |||
| for t in range(500): | |||
| y_pred = x.mm(w1).clamp(min=0).mm(w2) | |||
| loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale | |||
| print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) | |||
| print('Iter {} scaled loss: {}'.format(t, loss.data[0])) | |||
| print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) | |||
| # Run backprop | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| # Check for overflow | |||
| has_overflow = DynamicLossScaler.has_overflow(parameters) | |||
| # If no overflow, unscale grad and update as usual | |||
| if not has_overflow: | |||
| for param in parameters: | |||
| param.grad.data.mul_(1. / loss_scaler.loss_scale) | |||
| optimizer.step() | |||
| # Otherwise, don't do anything -- ie, skip iteration | |||
| else: | |||
| print('OVERFLOW!') | |||
| # Update loss scale for next iteration | |||
| loss_scaler.update_scale(has_overflow) | |||
| """ | |||
| @@ -172,6 +172,7 @@ class OfaTasksTest(unittest.TestCase): | |||
| ofa_pipe = pipeline(Tasks.visual_grounding, model=model) | |||
| image = 'data/test/images/visual_grounding.png' | |||
| text = '一个圆头的蓝色宝可梦' | |||
| text = '火' | |||
| input = {'image': image, 'text': text} | |||
| result = ofa_pipe(input) | |||
| print(result) | |||
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import shutil | |||
| import unittest | |||
| from modelscope.trainers.multi_modal.ofa import OFATrainer | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestOfaTrainer(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en' | |||
| self.trainer = OFATrainer(model_id) | |||
| self.trainer.train() | |||
| shutil.rmtree(self.trainer.save_dir) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||