Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10318299master
@@ -281,6 +281,7 @@ class Trainers(object): | |||
# multi-modal trainers | |||
clip_multi_modal_embedding = 'clip-multi-modal-embedding' | |||
ofa = 'ofa' | |||
# cv trainers | |||
image_instance_segmentation = 'image-instance-segmentation' | |||
@@ -375,6 +376,9 @@ class Metrics(object): | |||
accuracy = 'accuracy' | |||
audio_noise_metric = 'audio-noise-metric' | |||
# text gen | |||
BLEU = 'bleu' | |||
# metrics for image denoise task | |||
image_denoise_metric = 'image-denoise-metric' | |||
@@ -395,6 +399,8 @@ class Metrics(object): | |||
movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | |||
# metric for inpainting task | |||
image_inpainting_metric = 'image-inpainting-metric' | |||
# metric for ocr | |||
NED = 'ned' | |||
class Optimizers(object): | |||
@@ -17,6 +17,8 @@ if TYPE_CHECKING: | |||
from .token_classification_metric import TokenClassificationMetric | |||
from .video_summarization_metric import VideoSummarizationMetric | |||
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | |||
from .accuracy_metric import AccuracyMetric | |||
from .bleu_metric import BleuMetric | |||
from .image_inpainting_metric import ImageInpaintingMetric | |||
else: | |||
@@ -36,6 +38,8 @@ else: | |||
'video_summarization_metric': ['VideoSummarizationMetric'], | |||
'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | |||
'image_inpainting_metric': ['ImageInpaintingMetric'], | |||
'accuracy_metric': ['AccuracyMetric'], | |||
'bleu_metric': ['BleuMetric'], | |||
} | |||
import sys | |||
@@ -0,0 +1,46 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict | |||
import numpy as np | |||
from modelscope.metainfo import Metrics | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) | |||
class AccuracyMetric(Metric): | |||
"""The metric computation class for classification classes. | |||
This metric class calculates accuracy for the whole input batches. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.preds = [] | |||
self.labels = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | |||
ground_truths = inputs[label_name] | |||
eval_results = outputs[label_name] | |||
assert type(ground_truths) == type(eval_results) | |||
if isinstance(ground_truths, list): | |||
self.preds.extend(eval_results) | |||
self.labels.extend(ground_truths) | |||
elif isinstance(ground_truths, np.ndarray): | |||
self.preds.extend(eval_results.tolist()) | |||
self.labels.extend(ground_truths.tolist()) | |||
else: | |||
raise 'only support list or np.ndarray' | |||
def evaluate(self): | |||
assert len(self.preds) == len(self.labels) | |||
return { | |||
MetricKeys.ACCURACY: (np.asarray([ | |||
pred == ref for pred, ref in zip(self.preds, self.labels) | |||
])).mean().item() | |||
} |
@@ -0,0 +1,42 @@ | |||
from itertools import zip_longest | |||
from typing import Dict | |||
import sacrebleu | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
EVAL_BLEU_ORDER = 4 | |||
@METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) | |||
class BleuMetric(Metric): | |||
"""The metric computation bleu for text generation classes. | |||
This metric class calculates accuracy for the whole input batches. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) | |||
self.hyp_name = kwargs.get('hyp_name', 'hyp') | |||
self.ref_name = kwargs.get('ref_name', 'ref') | |||
self.refs = list() | |||
self.hyps = list() | |||
def add(self, outputs: Dict, inputs: Dict): | |||
self.refs.extend(inputs[self.ref_name]) | |||
self.hyps.extend(outputs[self.hyp_name]) | |||
def evaluate(self): | |||
if self.eval_tokenized_bleu: | |||
bleu = sacrebleu.corpus_bleu( | |||
self.hyps, list(zip_longest(*self.refs)), tokenize='none') | |||
else: | |||
bleu = sacrebleu.corpus_bleu(self.hyps, | |||
list(zip_longest(*self.refs))) | |||
return { | |||
MetricKeys.BLEU_4: bleu.score, | |||
} |
@@ -23,6 +23,7 @@ class MetricKeys(object): | |||
BLEU_4 = 'bleu-4' | |||
ROUGE_1 = 'rouge-1' | |||
ROUGE_L = 'rouge-l' | |||
NED = 'ned' # ocr metric | |||
task_default_metrics = { | |||
@@ -0,0 +1 @@ | |||
__author__ = 'tylin' |
@@ -0,0 +1,57 @@ | |||
# Filename: ciderD.py | |||
# | |||
# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric | |||
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) | |||
# | |||
# Creation Date: Sun Feb 8 14:16:54 2015 | |||
# | |||
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu> | |||
from __future__ import absolute_import, division, print_function | |||
from .ciderD_scorer import CiderScorer | |||
class CiderD: | |||
""" | |||
Main Class to compute the CIDEr metric | |||
""" | |||
def __init__(self, n=4, sigma=6.0, df='corpus'): | |||
# set cider to sum over 1 to 4-grams | |||
self._n = n | |||
# set the standard deviation parameter for gaussian penalty | |||
self._sigma = sigma | |||
# set which where to compute document frequencies from | |||
self._df = df | |||
self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) | |||
def compute_score(self, gts, res): | |||
""" | |||
Main function to compute CIDEr score | |||
:param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence> | |||
ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence> | |||
:return: cider (float) : computed CIDEr score for the corpus | |||
""" # noqa | |||
# clear all the previous hypos and refs | |||
tmp_cider_scorer = self.cider_scorer.copy_empty() | |||
tmp_cider_scorer.clear() | |||
for res_id in res: | |||
hypo = res_id['caption'] | |||
ref = gts[res_id['image_id']] | |||
# Sanity check. | |||
assert (type(hypo) is list) | |||
assert (len(hypo) == 1) | |||
assert (type(ref) is list) | |||
assert (len(ref) > 0) | |||
tmp_cider_scorer += (hypo[0], ref) | |||
(score, scores) = tmp_cider_scorer.compute_score() | |||
return score, scores | |||
def method(self): | |||
return 'CIDEr-D' |
@@ -0,0 +1,233 @@ | |||
#!/usr/bin/env python | |||
# Tsung-Yi Lin <tl483@cornell.edu> | |||
# Ramakrishna Vedantam <vrama91@vt.edu> | |||
from __future__ import absolute_import, division, print_function | |||
import copy | |||
import math | |||
import os | |||
import pdb | |||
from collections import defaultdict | |||
import numpy as np | |||
import six | |||
from six.moves import cPickle | |||
def precook(s, n=4, out=False): | |||
""" | |||
Takes a string as input and returns an object that can be given to | |||
either cook_refs or cook_test. This is optional: cook_refs and cook_test | |||
can take string arguments as well. | |||
:param s: string : sentence to be converted into ngrams | |||
:param n: int : number of ngrams for which representation is calculated | |||
:return: term frequency vector for occuring ngrams | |||
""" | |||
words = s.split() | |||
counts = defaultdict(int) | |||
for k in range(1, n + 1): | |||
for i in range(len(words) - k + 1): | |||
ngram = tuple(words[i:i + k]) | |||
counts[ngram] += 1 | |||
return counts | |||
def cook_refs(refs, n=4): # lhuang: oracle will call with "average" | |||
'''Takes a list of reference sentences for a single segment | |||
and returns an object that encapsulates everything that BLEU | |||
needs to know about them. | |||
:param refs: list of string : reference sentences for some image | |||
:param n: int : number of ngrams for which (ngram) representation is calculated | |||
:return: result (list of dict) | |||
''' | |||
return [precook(ref, n) for ref in refs] | |||
def cook_test(test, n=4): | |||
'''Takes a test sentence and returns an object that | |||
encapsulates everything that BLEU needs to know about it. | |||
:param test: list of string : hypothesis sentence for some image | |||
:param n: int : number of ngrams for which (ngram) representation is calculated | |||
:return: result (dict) | |||
''' | |||
return precook(test, n, True) | |||
class CiderScorer(object): | |||
"""CIDEr scorer. | |||
""" | |||
def copy(self): | |||
''' copy the refs.''' | |||
new = CiderScorer(n=self.n) | |||
new.ctest = copy.copy(self.ctest) | |||
new.crefs = copy.copy(self.crefs) | |||
return new | |||
def copy_empty(self): | |||
new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) | |||
new.df_mode = self.df_mode | |||
new.ref_len = self.ref_len | |||
new.document_frequency = self.document_frequency | |||
return new | |||
def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): | |||
''' singular instance ''' | |||
self.n = n | |||
self.sigma = sigma | |||
self.crefs = [] | |||
self.ctest = [] | |||
self.df_mode = df_mode | |||
self.ref_len = None | |||
if self.df_mode != 'corpus': | |||
pkl_file = cPickle.load( | |||
open(df_mode, 'rb'), | |||
**(dict(encoding='latin1') if six.PY3 else {})) | |||
self.ref_len = np.log(float(pkl_file['ref_len'])) | |||
self.document_frequency = pkl_file['document_frequency'] | |||
else: | |||
self.document_frequency = None | |||
self.cook_append(test, refs) | |||
def clear(self): | |||
self.crefs = [] | |||
self.ctest = [] | |||
def cook_append(self, test, refs): | |||
'''called by constructor and __iadd__ to avoid creating new instances.''' | |||
if refs is not None: | |||
self.crefs.append(cook_refs(refs)) | |||
if test is not None: | |||
self.ctest.append(cook_test(test)) # N.B.: -1 | |||
else: | |||
self.ctest.append( | |||
None) # lens of crefs and ctest have to match | |||
def size(self): | |||
assert len(self.crefs) == len( | |||
self.ctest), 'refs/test mismatch! %d<>%d' % (len( | |||
self.crefs), len(self.ctest)) | |||
return len(self.crefs) | |||
def __iadd__(self, other): | |||
'''add an instance (e.g., from another sentence).''' | |||
if type(other) is tuple: | |||
# avoid creating new CiderScorer instances | |||
self.cook_append(other[0], other[1]) | |||
else: | |||
self.ctest.extend(other.ctest) | |||
self.crefs.extend(other.crefs) | |||
return self | |||
def compute_doc_freq(self): | |||
""" | |||
Compute term frequency for reference data. | |||
This will be used to compute idf (inverse document frequency later) | |||
The term frequency is stored in the object | |||
:return: None | |||
""" | |||
for refs in self.crefs: | |||
# refs, k ref captions of one image | |||
for ngram in set([ | |||
ngram for ref in refs for (ngram, count) in ref.items() | |||
]): # noqa | |||
self.document_frequency[ngram] += 1 | |||
def compute_cider(self): | |||
def counts2vec(cnts): | |||
""" | |||
Function maps counts of ngram to vector of tfidf weights. | |||
The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. | |||
The n-th entry of array denotes length of n-grams. | |||
:param cnts: | |||
:return: vec (array of dict), norm (array of float), length (int) | |||
""" | |||
vec = [defaultdict(float) for _ in range(self.n)] | |||
length = 0 | |||
norm = [0.0 for _ in range(self.n)] | |||
for (ngram, term_freq) in cnts.items(): | |||
# give word count 1 if it doesn't appear in reference corpus | |||
df = np.log(max(1.0, self.document_frequency[ngram])) | |||
# ngram index | |||
n = len(ngram) - 1 | |||
# tf (term_freq) * idf (precomputed idf) for n-grams | |||
vec[n][ngram] = float(term_freq) * (self.ref_len - df) | |||
# compute norm for the vector. the norm will be used for computing similarity | |||
norm[n] += pow(vec[n][ngram], 2) | |||
if n == 1: | |||
length += term_freq | |||
norm = [np.sqrt(n) for n in norm] | |||
return vec, norm, length | |||
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): | |||
''' | |||
Compute the cosine similarity of two vectors. | |||
:param vec_hyp: array of dictionary for vector corresponding to hypothesis | |||
:param vec_ref: array of dictionary for vector corresponding to reference | |||
:param norm_hyp: array of float for vector corresponding to hypothesis | |||
:param norm_ref: array of float for vector corresponding to reference | |||
:param length_hyp: int containing length of hypothesis | |||
:param length_ref: int containing length of reference | |||
:return: array of score for each n-grams cosine similarity | |||
''' | |||
delta = float(length_hyp - length_ref) | |||
# measure consine similarity | |||
val = np.array([0.0 for _ in range(self.n)]) | |||
for n in range(self.n): | |||
# ngram | |||
for (ngram, count) in vec_hyp[n].items(): | |||
# vrama91 : added clipping | |||
val[n] += min(vec_hyp[n][ngram], | |||
vec_ref[n][ngram]) * vec_ref[n][ngram] | |||
if (norm_hyp[n] != 0) and (norm_ref[n] != 0): | |||
val[n] /= (norm_hyp[n] * norm_ref[n]) | |||
assert (not math.isnan(val[n])) | |||
# vrama91: added a length based gaussian penalty | |||
val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) | |||
return val | |||
# compute log reference length | |||
if self.df_mode == 'corpus': | |||
self.ref_len = np.log(float(len(self.crefs))) | |||
# elif self.df_mode == "coco-val-df": | |||
# if coco option selected, use length of coco-val set | |||
# self.ref_len = np.log(float(40504)) | |||
scores = [] | |||
for test, refs in zip(self.ctest, self.crefs): | |||
# compute vector for test captions | |||
vec, norm, length = counts2vec(test) | |||
# compute vector for ref captions | |||
score = np.array([0.0 for _ in range(self.n)]) | |||
for ref in refs: | |||
vec_ref, norm_ref, length_ref = counts2vec(ref) | |||
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) | |||
# change by vrama91 - mean of ngram scores, instead of sum | |||
score_avg = np.mean(score) | |||
# divide by number of references | |||
score_avg /= len(refs) | |||
# multiply score by 10 | |||
score_avg *= 10.0 | |||
# append score of an image to the score list | |||
scores.append(score_avg) | |||
return scores | |||
def compute_score(self, option=None, verbose=0): | |||
# compute idf | |||
if self.df_mode == 'corpus': | |||
self.document_frequency = defaultdict(float) | |||
self.compute_doc_freq() | |||
# assert to check document frequency | |||
assert (len(self.ctest) >= max(self.document_frequency.values())) | |||
# import json for now and write the corresponding files | |||
# compute cider score | |||
score = self.compute_cider() | |||
# debug | |||
# print score | |||
return np.mean(np.array(score)), np.array(score) |
@@ -148,7 +148,7 @@ class BeamSearch(Search): | |||
scores_buf = top_prediction[0] | |||
indices_buf = top_prediction[1] | |||
# Project back into relative indices and beams | |||
beams_buf = indices_buf // vocab_size | |||
beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor') | |||
indices_buf = indices_buf.fmod(vocab_size) | |||
# At this point, beams_buf and indices_buf are single-dim and contain relative indices | |||
@@ -385,12 +385,7 @@ class SequenceGenerator(nn.Module): | |||
attn = torch.empty(bsz * beam_size, | |||
avg_attn_scores.size(1), | |||
max_len + 2).to(scores) | |||
# print("+++++++ debug attention shape +++++++") | |||
# print("attn", attn.shape) | |||
# print("avg_attn_scores", avg_attn_scores.shape) | |||
attn[:, :, step + 1].copy_(avg_attn_scores) | |||
# print("attn[:, :, step + 1]", attn[:, :, step + 1].shape) | |||
# print("attn", attn.shape) | |||
scores = scores.type_as(lprobs) | |||
eos_bbsz_idx = torch.empty(0).to( | |||
@@ -404,8 +399,28 @@ class SequenceGenerator(nn.Module): | |||
self.search.set_src_lengths(src_lengths) | |||
if self.repeat_ngram_blocker is not None: | |||
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, | |||
beam_size, step) | |||
# process prefix_tokens | |||
p_toks_len = prefix_tokens.ne(self.pad).sum( | |||
dim=1) if prefix_tokens is not None else None | |||
if p_toks_len is not None: | |||
p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat( | |||
1, beam_size).view(-1) | |||
no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size | |||
out_prefix = p_toks_len_beam < ( | |||
step + no_repeat_ngram_size - 1) | |||
else: | |||
out_prefix = torch.ones(bsz * beam_size).bool() | |||
ngram_blocker_tokens = tokens[out_prefix] | |||
ngram_blocker_lprobs = lprobs[out_prefix] | |||
ngram_blocker_bsz = torch.div( | |||
out_prefix.sum(), beam_size, rounding_mode='trunc') | |||
lprobs[out_prefix] = self.repeat_ngram_blocker( | |||
tokens=ngram_blocker_tokens, | |||
lprobs=ngram_blocker_lprobs, | |||
bsz=ngram_blocker_bsz, | |||
beam_size=beam_size, | |||
step=step) | |||
# Shape: (batch, cand_size) | |||
cand_scores, cand_indices, cand_beams = self.search.step( | |||
@@ -415,7 +430,6 @@ class SequenceGenerator(nn.Module): | |||
tokens[:, :step + 1], | |||
original_batch_idxs, | |||
) | |||
# cand_bbsz_idx contains beam indices for the top candidate | |||
# hypotheses, with a range of values: [0, bsz*beam_size), | |||
# and dimensions: [bsz, cand_size] | |||
@@ -671,7 +685,7 @@ class SequenceGenerator(nn.Module): | |||
cum_unfin.append(prev) | |||
cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) | |||
unfin_idx = bbsz_idx // beam_size | |||
unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor') | |||
sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) | |||
# Create a set of "{sent}{unfin_idx}", where | |||
@@ -19,6 +19,7 @@ from dataclasses import dataclass | |||
from typing import Dict, List, Optional, Tuple | |||
import torch | |||
from packaging import version | |||
from torch import Tensor, nn | |||
from torch.nn import functional as F | |||
from transformers.activations import ACT2FN | |||
@@ -40,6 +41,8 @@ logger = logging.get_logger(__name__) | |||
_CHECKPOINT_FOR_DOC = 'ofa-base' | |||
_CONFIG_FOR_DOC = 'OFAConfig' | |||
_TOKENIZER_FOR_DOC = 'OFATokenizer' | |||
TORCH_VERSION = version.parse(torch.__version__) | |||
TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') | |||
DEFAULT_MAX_SOURCE_POSITIONS = 1024 | |||
DEFAULT_MAX_TARGET_POSITIONS = 1024 | |||
@@ -51,6 +54,7 @@ OFA_PRETRAINED_MODEL_ARCHIVE_LIST = [ | |||
'ofa-medium', | |||
'ofa-base', | |||
'ofa-large', | |||
'ofa-huge', | |||
] | |||
try: | |||
@@ -114,7 +118,11 @@ def make_image_bucket_position(bucket_size, num_relative_distance): | |||
""" | |||
coords_h = torch.arange(bucket_size) | |||
coords_w = torch.arange(bucket_size) | |||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||
if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION: | |||
coords = torch.stack( | |||
torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww | |||
else: | |||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) | |||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
relative_coords = coords_flatten[:, :, None] - \ | |||
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww | |||
@@ -8,7 +8,7 @@ OFA_TASK_KEY_MAPPING = { | |||
Tasks.text_summarization: OutputKeys.TEXT, | |||
Tasks.visual_question_answering: OutputKeys.TEXT, | |||
Tasks.visual_grounding: OutputKeys.BOXES, | |||
Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS), | |||
Tasks.text_classification: OutputKeys.LABELS, | |||
Tasks.image_classification: OutputKeys.LABELS, | |||
Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS), | |||
Tasks.visual_entailment: OutputKeys.LABELS, | |||
} |
@@ -1,8 +1,10 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
import string | |||
from functools import partial | |||
from os import path as osp | |||
from typing import Any, Dict | |||
from typing import Any, Callable, Dict, List, Optional, Union | |||
import json | |||
import torch.cuda | |||
@@ -10,7 +12,6 @@ import torch.nn.functional as F | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.preprocessors.ofa.utils.collate import collate_tokens | |||
@@ -66,10 +67,9 @@ class OfaForAllTasks(TorchModel): | |||
self.gen_type = self.cfg.model.get('gen_type', 'generation') | |||
assert self.gen_type in ['generation', 'traverse'], \ | |||
'model.gen_type must be in ["generation", "traverse"]' | |||
self._device = torch.device('cuda') if torch.cuda.is_available() \ | |||
else torch.device('cpu') | |||
self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id | |||
]).to(self._device) | |||
self.bos_item = torch.LongTensor([self.tokenizer.bos_token_id]) | |||
self.pad_item = torch.LongTensor([self.tokenizer.pad_token_id]) | |||
self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id]) | |||
self.index2ans = {} | |||
self.ans2label_dict = {} | |||
self.load_ans2label() | |||
@@ -90,7 +90,8 @@ class OfaForAllTasks(TorchModel): | |||
self.val_masks_l = [] | |||
self.build_trie() | |||
sg_args['constraint_trie'] = self.constraint_trie | |||
self.model.to(self._device) | |||
else: | |||
self.constraint_trie = None | |||
self.generator = sg.SequenceGenerator(**sg_args) | |||
inference_d = { | |||
'generation': self._text_gen_inference, | |||
@@ -108,8 +109,16 @@ class OfaForAllTasks(TorchModel): | |||
} | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
input = move_to_device(input, self.model.device) | |||
if self.model.training: | |||
return self.model(**input['net_input']) | |||
else: | |||
return self.inference(input) | |||
def inference(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
ret = self.task_inference_mapping[self.cfg.task](input) | |||
ret['samples'] = input['samples'] | |||
if 'samples' in input: | |||
ret['samples'] = input['samples'] | |||
for key in [ | |||
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, | |||
OutputKeys.LABELS, OutputKeys.SCORES | |||
@@ -118,21 +127,33 @@ class OfaForAllTasks(TorchModel): | |||
ret[key] = None | |||
return ret | |||
def postprocess(self, input: Dict[str, Tensor], | |||
**kwargs) -> Dict[str, Tensor]: | |||
if self.cfg.task == Tasks.image_captioning: | |||
caption = [ | |||
cap.translate(self.transtab).strip() | |||
for cap in input[OutputKeys.CAPTION] | |||
] | |||
input[OutputKeys.CAPTION] = caption | |||
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
if not self.model.training and self.cfg.task == Tasks.image_captioning: | |||
caption = input[OutputKeys.CAPTION] | |||
result_l = list() | |||
for cap in caption: | |||
result_l.append(cap.translate(self.transtab).strip()) | |||
input[OutputKeys.CAPTION] = result_l | |||
return input | |||
def _text_gen_inference(self, input): | |||
input = move_to_device(input, self._device) | |||
gen_output = self.generator.generate([self.model], input) | |||
gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] | |||
result = self.tokenizer.batch_decode(gen, skip_special_tokens=True) | |||
gen_outputs = self.generator.generate([self.model], | |||
input, | |||
prefix_tokens=input.get( | |||
'prefix_tokens', None)) | |||
gen_l = list() | |||
for idx, gen_out in enumerate(gen_outputs): | |||
if len(gen_out) > 0: | |||
decode_tokens = gen_out[0]['tokens'] | |||
if 'prefix_tokens' in input: | |||
prefix_len = input['prefix_tokens'][idx].ne( | |||
self.pad_item.to(self.model.device)).sum() | |||
decode_tokens = decode_tokens[prefix_len:] | |||
gen_l.append(decode_tokens) | |||
else: | |||
gen_l.append('') | |||
result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True) | |||
result = [item.strip() for item in result] | |||
# text generation tasks have no score | |||
ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result} | |||
if self.cfg.task.endswith('classification'): | |||
@@ -140,7 +161,6 @@ class OfaForAllTasks(TorchModel): | |||
return ret | |||
def _visual_grounding_inference(self, input): | |||
input = move_to_device(input, self._device) | |||
gen_output = self.generator.generate([self.model], input) | |||
tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] | |||
region_coord_l = list() | |||
@@ -160,7 +180,6 @@ class OfaForAllTasks(TorchModel): | |||
} | |||
def _traverse_inference(self, input): | |||
input = move_to_device(input, self._device) | |||
encoder_input = dict() | |||
for key in input['net_input'].keys(): | |||
encoder_input[key] = input['net_input'][key] | |||
@@ -170,13 +189,14 @@ class OfaForAllTasks(TorchModel): | |||
valid_size = len(val_ans) | |||
valid_tgt_items = [ | |||
torch.cat([ | |||
torch.tensor(decoder_prompt[1:]), valid_answer, | |||
torch.tensor(decoder_prompt[1:]).to('cpu'), valid_answer, | |||
self.eos_item | |||
]) for decoder_prompt in input['decoder_prompts'] | |||
for valid_answer in val_ans | |||
] | |||
valid_prev_items = [ | |||
torch.cat([torch.tensor(decoder_prompt), valid_answer]) | |||
torch.cat( | |||
[torch.tensor(decoder_prompt).to('cpu'), valid_answer]) | |||
for decoder_prompt in input['decoder_prompts'] | |||
for valid_answer in val_ans | |||
] | |||
@@ -184,19 +204,19 @@ class OfaForAllTasks(TorchModel): | |||
torch.cat([ | |||
torch.zeros( | |||
len(decoder_prompt) - 1, | |||
valid_constraint_mask.size(1)).bool().to(self._device), | |||
valid_constraint_mask.size(1)).bool(), | |||
valid_constraint_mask], dim=0) # yapf: disable | |||
for decoder_prompt in input['decoder_prompts'] # yapf: disable | |||
for valid_constraint_mask in val_masks] # yapf: disable | |||
valid_tgt = collate_tokens( | |||
valid_tgt_items, | |||
pad_idx=self.tokenizer.pad_token_id).to(self._device) | |||
pad_idx=self.tokenizer.pad_token_id).to(self.model.device) | |||
valid_prev_output = collate_tokens( | |||
valid_prev_items, | |||
pad_idx=self.tokenizer.pad_token_id).to(self._device) | |||
pad_idx=self.tokenizer.pad_token_id).to(self.model.device) | |||
val_masks = collate_tokens( | |||
valid_constraint_mask_items, | |||
pad_idx=self.tokenizer.pad_token_id).to(self._device) | |||
pad_idx=self.tokenizer.pad_token_id).to(self.model.device) | |||
new_encoder_out = { | |||
'last_hidden_state': | |||
encoder_out['last_hidden_state'].repeat_interleave( | |||
@@ -271,10 +291,23 @@ class OfaForAllTasks(TorchModel): | |||
self.val_masks_l += [ | |||
constraint_mask_list[i:i + self.val_batch_size] | |||
] | |||
self.val_ans_l = move_to_device(self.val_ans_l, self._device) | |||
self.val_masks_l = move_to_device(self.val_masks_l, self._device) | |||
def load_ans2label(self): | |||
if self.cfg.model.get('answer2label', None): | |||
filename = osp.join(self.model_dir, self.cfg.model.answer2label) | |||
self.ans2label_dict = json.load(open(filename)) | |||
ans2label_file = osp.join(self.model_dir, | |||
self.cfg.model.answer2label) | |||
with open(ans2label_file, 'r') as reader: | |||
self.ans2label_dict = json.load(reader) | |||
def save_pretrained(self, | |||
target_folder: Union[str, os.PathLike], | |||
save_checkpoint_names: Union[str, List[str]] = None, | |||
save_function: Callable = None, | |||
config: Optional[dict] = None, | |||
**kwargs): | |||
super(OfaForAllTasks, self). \ | |||
save_pretrained(target_folder=target_folder, | |||
save_checkpoint_names=save_checkpoint_names, | |||
save_function=partial(save_function, with_meta=False), | |||
config=config, | |||
**kwargs) |
@@ -13,6 +13,7 @@ from modelscope.pipelines.base import Input, Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.device import get_device | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -36,6 +37,7 @@ class ImageClassificationPipeline(Pipeline): | |||
else: | |||
raise NotImplementedError | |||
pipe_model.model.eval() | |||
pipe_model.to(get_device()) | |||
if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): | |||
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) | |||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
@@ -1,5 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
from io import BytesIO | |||
from typing import Any, Dict, List, Tuple, Union | |||
import torch | |||
@@ -15,6 +16,7 @@ from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
from .ofa import * # noqa | |||
from .ofa.utils.collate import collate_fn | |||
from .ofa.utils.constant import OFA_TASK_KEY_MAPPING | |||
__all__ = [ | |||
'OfaPreprocessor', | |||
@@ -26,11 +28,16 @@ __all__ = [ | |||
Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) | |||
class OfaPreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
def __init__(self, | |||
model_dir: str, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
model_dir (str): model path | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super().__init__(*args, **kwargs) | |||
preprocess_mapping = { | |||
@@ -45,25 +52,18 @@ class OfaPreprocessor(Preprocessor): | |||
Tasks.text_summarization: OfaSummarizationPreprocessor, | |||
Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | |||
} | |||
input_key_mapping = { | |||
Tasks.ocr_recognition: ['image'], | |||
Tasks.image_captioning: ['image'], | |||
Tasks.image_classification: ['image'], | |||
Tasks.text_summarization: ['text'], | |||
Tasks.text_classification: ['text', 'text2'], | |||
Tasks.visual_grounding: ['image', 'text'], | |||
Tasks.visual_question_answering: ['image', 'text'], | |||
Tasks.visual_entailment: ['image', 'text', 'text2'], | |||
Tasks.text_to_image_synthesis: ['text'] | |||
} | |||
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( | |||
model_dir) | |||
self.cfg = Config.from_file( | |||
osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, | |||
model_dir) | |||
self.keys = input_key_mapping[self.cfg.task] | |||
self.preprocess = preprocess_mapping[self.cfg.task]( | |||
cfg=self.cfg, model_dir=model_dir, mode=mode) | |||
self.keys = OFA_TASK_KEY_MAPPING[self.cfg.task] | |||
self.tokenizer = self.preprocess.tokenizer | |||
if kwargs.get('no_collate', None): | |||
self.no_collate = True | |||
else: | |||
self.no_collate = False | |||
# just for modelscope demo | |||
def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]: | |||
@@ -74,20 +74,37 @@ class OfaPreprocessor(Preprocessor): | |||
data[key] = item | |||
return data | |||
def _ofa_input_compatibility_conversion(self, data): | |||
if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | |||
if isinstance(data['image'], str): | |||
image = load_image(data['image']) | |||
else: | |||
image = data['image'] | |||
if image.mode != 'RGB': | |||
image = image.convert('RGB') | |||
img_buffer = BytesIO() | |||
image.save(img_buffer, format='JPEG') | |||
data['image'] = Image.open(img_buffer) | |||
return data | |||
def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, | |||
**kwargs) -> Dict[str, Any]: | |||
if isinstance(input, dict): | |||
data = input | |||
else: | |||
data = self._build_dict(input) | |||
data = self._ofa_input_compatibility_conversion(data) | |||
sample = self.preprocess(data) | |||
str_data = dict() | |||
for k, v in data.items(): | |||
str_data[k] = str(v) | |||
sample['sample'] = str_data | |||
return collate_fn([sample], | |||
pad_idx=self.tokenizer.pad_token_id, | |||
eos_idx=self.tokenizer.eos_token_id) | |||
if self.no_collate: | |||
return sample | |||
else: | |||
return collate_fn([sample], | |||
pad_idx=self.tokenizer.pad_token_id, | |||
eos_idx=self.tokenizer.eos_token_id) | |||
@PREPROCESSORS.register_module( | |||
@@ -140,7 +157,7 @@ class MPlugPreprocessor(Preprocessor): | |||
def image_open(self, path: str) -> Tuple[Image.Image, int]: | |||
if path not in self._image_map: | |||
index = len(self._image_map) | |||
self._image_map[path] = (load_image(path), index) | |||
self._image_map[path] = (Image.open(path), index) | |||
return self._image_map[path] | |||
def __call__( | |||
@@ -1,26 +1,31 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import re | |||
import string | |||
from os import path as osp | |||
import json | |||
import numpy as np | |||
import torch | |||
from PIL import Image | |||
from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.trie import Trie | |||
from .utils.constant import OFA_TASK_KEY_MAPPING | |||
from .utils.random_help import set_torch_seed | |||
class OfaBasePreprocessor: | |||
def __init__(self, cfg, model_dir): | |||
"""preprocess the data | |||
def __init__(self, cfg, model_dir, mode, *args, **kwargs): | |||
"""preprocess the data via the vocab.txt from the `model_dir` path | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
""" | |||
self.cfg = cfg | |||
self.mode = mode | |||
self.language = self.cfg.model.get('language', 'en') | |||
if self.language == 'en': | |||
tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
@@ -41,6 +46,7 @@ class OfaBasePreprocessor: | |||
for key, value in tokenizer.get_vocab().items() | |||
} | |||
self.max_src_length = cfg.model.get('max_src_length', 256) | |||
self.max_tgt_length = cfg.model.get('max_tgt_length', 256) | |||
self.max_image_size = cfg.model.get('max_image_size', 512) | |||
self.language = self.cfg.model.get('language', 'en') | |||
self.prompt_type = self.cfg.model.get('prompt_type', 'none') | |||
@@ -56,26 +62,40 @@ class OfaBasePreprocessor: | |||
self.mean = [0.5, 0.5, 0.5] | |||
self.std = [0.5, 0.5, 0.5] | |||
self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | |||
self.column_map = { | |||
key: key | |||
for key in OFA_TASK_KEY_MAPPING[self.cfg.task] | |||
} | |||
if hasattr(self.cfg, | |||
'dataset') and self.cfg.dataset.column_map is not None: | |||
for k, v in self.cfg.dataset.column_map.items(): | |||
self.column_map[k] = v | |||
self.transtab = str.maketrans( | |||
{key: None | |||
for key in string.punctuation}) | |||
self.constraint_trie = None | |||
self.index2ans = {} | |||
if self.cfg.model.get('answer2label', False): | |||
if self.cfg.model.get('answer2label', None): | |||
ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) | |||
ans2label_dict = json.load(open(ans2label_file, 'r')) | |||
with open(ans2label_file, 'r') as reader: | |||
ans2label_dict = json.load(reader) | |||
self.ans2label = ans2label_dict | |||
self.label2ans = {v: k for k, v in self.ans2label.items()} | |||
self.constraint_trie = Trie(tokenizer.eos_token_id) | |||
for i, answer in enumerate(ans2label_dict.keys()): | |||
answer_item = tokenizer( | |||
' ' + answer, | |||
return_tensors='pt', | |||
add_special_tokens=False).input_ids.squeeze(0) | |||
answer_item = self.tokenize_text( | |||
' ' + answer, add_bos=False, add_eos=False) | |||
self.constraint_trie.insert([tokenizer.bos_token_id] | |||
+ answer_item.tolist() | |||
+ [tokenizer.eos_token_id]) | |||
def get_inputs(self, text, add_bos=True, add_eos=True): | |||
def tokenize_text(self, text, add_bos=True, add_eos=True): | |||
if text is None: | |||
return None | |||
inputs = self.tokenizer( | |||
text, | |||
max_length=self.max_src_length, | |||
add_special_tokens=False, | |||
truncation=True, | |||
return_tensors='pt')['input_ids'].squeeze(0) | |||
if add_bos: | |||
inputs = torch.cat([self.bos_item, inputs]) | |||
@@ -85,7 +105,7 @@ class OfaBasePreprocessor: | |||
@staticmethod | |||
def pre_caption(caption, max_words=None): | |||
caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\ | |||
caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \ | |||
.replace('/', ' ').replace('<person>', 'person') | |||
caption = re.sub( | |||
@@ -123,3 +143,23 @@ class OfaBasePreprocessor: | |||
question = ' '.join(question_words[:max_ques_words]) | |||
return question | |||
def add_constraint_mask(self, sample): | |||
target_itm = sample['target'] | |||
len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item() | |||
if self.constraint_trie: | |||
constraint_mask = torch.zeros( | |||
(len(target_itm), len(self.tgt_dict))).bool() | |||
start_idx = len(target_itm) - len_label_itm | |||
for i in range(start_idx, len(target_itm)): | |||
constraint_prefix_token = self.bos_item.tolist( | |||
) + target_itm[start_idx:i].tolist() | |||
constraint_nodes = self.constraint_trie.get_next_layer( | |||
constraint_prefix_token) | |||
constraint_mask[i][constraint_nodes] = True | |||
sample['constraint_mask'] = constraint_mask | |||
def get_img_pil(self, path_or_url_or_pil): | |||
image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ | |||
else load_image(path_or_url_or_pil) | |||
return image |
@@ -1,42 +1,67 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Union | |||
from typing import Any, Dict | |||
import torch | |||
from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir) | |||
super(OfaImageCaptioningPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: image.convert('RGB'), | |||
transforms.Resize((self.patch_image_size, self.patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.Resize( | |||
(self.patch_image_size, self.patch_image_size), | |||
interpolation=transforms.InterpolationMode.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=self.mean, std=self.std), | |||
]) | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
image = data['image'] if isinstance( | |||
data['image'], Image.Image) else load_image(data['image']) | |||
if self.mode == ModeKeys.TRAIN: | |||
return self._build_train_sample(data) | |||
else: | |||
return self._build_infer_sample(data) | |||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
sample = self._build_infer_sample(data) | |||
target = data[self.column_map['text']] | |||
target = target.translate(self.transtab).strip() | |||
target_token_list = target.strip().split() | |||
target = ' '.join(target_token_list[:self.max_tgt_length]) | |||
sample['target'] = self.tokenize_text(target, add_bos=False) | |||
sample['prev_output_tokens'] = torch.cat( | |||
[self.bos_item, sample['target'][:-1]]) | |||
return sample | |||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
image = self.get_img_pil(data[self.column_map['image']]) | |||
patch_image = self.patch_resize_transform(image) | |||
prompt = self.cfg.model.get('prompt', ' what does the image describe?') | |||
inputs = self.get_inputs(prompt) | |||
inputs = self.tokenize_text(prompt) | |||
sample = { | |||
'source': inputs, | |||
'patch_image': patch_image, | |||
'patch_mask': torch.tensor([True]) | |||
} | |||
if 'text' in self.column_map and self.column_map['text'] in data: | |||
sample['label'] = data[self.column_map['text']] | |||
return sample |
@@ -6,25 +6,33 @@ from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaImageClassificationPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaImageClassificationPreprocessor, | |||
self).__init__(cfg, model_dir) | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: image.convert('RGB'), | |||
transforms.Resize((self.patch_image_size, self.patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.Resize( | |||
(self.patch_image_size, self.patch_image_size), | |||
interpolation=transforms.InterpolationMode.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=self.mean, std=self.std), | |||
]) | |||
@@ -34,7 +42,7 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor): | |||
data['image'], Image.Image) else load_image(data['image']) | |||
patch_image = self.patch_resize_transform(image) | |||
prompt = self.cfg.model.get('prompt', ' what does the image describe?') | |||
inputs = self.get_inputs(prompt) | |||
inputs = self.tokenize_text(prompt) | |||
sample = { | |||
'source': inputs, | |||
'patch_image': patch_image, | |||
@@ -1,7 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import random | |||
import unicodedata | |||
from typing import Any, Dict, Union | |||
from typing import Any, Dict | |||
import torch | |||
from PIL import Image | |||
@@ -10,6 +8,7 @@ from torchvision.transforms import InterpolationMode | |||
from torchvision.transforms import functional as F | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) | |||
@@ -59,14 +58,21 @@ def ocr_resize(img, patch_image_size, is_document=False): | |||
class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir) | |||
super(OfaOcrRecognitionPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
if self.cfg.model.imagenet_default_mean_and_std: | |||
mean = IMAGENET_DEFAULT_MEAN | |||
@@ -89,7 +95,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): | |||
data['image'], Image.Image) else load_image(data['image']) | |||
patch_image = self.patch_resize_transform(image) | |||
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') | |||
inputs = self.get_inputs(prompt) | |||
inputs = self.tokenize_text(prompt) | |||
sample = { | |||
'source': inputs, | |||
@@ -1,19 +1,27 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaSummarizationPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir) | |||
super(OfaSummarizationPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
source = super().pre_caption( | |||
@@ -23,7 +31,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor): | |||
prompt = self.cfg.model.get( | |||
'prompt', ' " {} " Summarize the article with a title: ') | |||
text = prompt.format(source) | |||
inputs = self.get_inputs(text) | |||
inputs = self.tokenize_text(text) | |||
if self.prompt_type == 'none': | |||
decoder_prompt = self.bos_item | |||
elif self.prompt_type == 'prev_output': | |||
@@ -1,38 +1,81 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaTextClassificationPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir) | |||
super(OfaTextClassificationPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
if self.mode == ModeKeys.TRAIN: | |||
return self._build_train_sample(data) | |||
else: | |||
return self._build_infer_sample(data) | |||
def _build_instruction(self, data): | |||
text1 = ' '.join( | |||
data['text'].lower().strip().split()[:self.max_src_length]) | |||
text2 = ' '.join( | |||
data['text2'].lower().strip().split()[:self.max_src_length]) | |||
prompt = ' can text1 " {} " imply text2 " {} "?' | |||
text = prompt.format(text1, text2) | |||
inputs = self.get_inputs(text) | |||
instruction_itm = self.tokenize_text(text) | |||
return instruction_itm | |||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
instruction_itm = self._build_instruction(data) | |||
assert 'label' in data, 'there must has `label` column in train phase ' | |||
label = data['label'] | |||
if self.label2ans: | |||
label = self.label2ans[label] # ans | |||
label_itm = self.tokenize_text(f' {label}', add_bos=False) | |||
if self.prompt_type == 'none': | |||
target_itm = label_itm | |||
elif self.prompt_type == 'prev_output': | |||
target_itm = torch.cat([instruction_itm[1:-1], label_itm]) | |||
else: | |||
raise NotImplementedError | |||
prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]]) | |||
target_itm[:-len(label_itm)] = self.pad_item | |||
sample = { | |||
'source': instruction_itm, | |||
'target': target_itm, | |||
'prev_output_tokens': prev_output_itm, | |||
} | |||
self.add_constraint_mask(sample) | |||
return sample | |||
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
instruction_itm = self._build_instruction(data) | |||
if self.prompt_type == 'none': | |||
decoder_prompt = self.bos_item | |||
elif self.prompt_type == 'src': | |||
decoder_prompt = inputs | |||
prefix_token = [] | |||
elif self.prompt_type == 'prev_output': | |||
decoder_prompt = inputs[:-1] | |||
prefix_token = instruction_itm[:-1] # remove eos | |||
else: | |||
raise NotImplementedError | |||
sample = { | |||
'source': inputs, | |||
'decoder_prompt': decoder_prompt, | |||
'source': instruction_itm, | |||
'prefix_token': prefix_token, | |||
} | |||
if 'label' in data: | |||
sample['label'] = self.label2ans[data['label']] | |||
return sample |
@@ -3,26 +3,34 @@ from typing import Any, Dict | |||
import torch | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
model_dir (str): model path | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaTextToImageSynthesisPreprocessor, | |||
self).__init__(cfg, model_dir) | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
self.max_src_length = 64 | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
source = ' '.join( | |||
data['text'].lower().strip().split()[:self.max_src_length]) | |||
source = 'what is the complete image? caption: {}'.format(source) | |||
inputs = self.get_inputs(source) | |||
inputs = self.tokenize_text(source) | |||
sample = { | |||
'source': inputs, | |||
'patch_images': None, | |||
@@ -49,11 +49,15 @@ def collate_fn(samples, pad_idx, eos_idx): | |||
batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0) | |||
if samples[0].get('ref_dict', None) is not None: | |||
batch['ref_dict'] = np.array([s['ref_dict'] for s in samples]) | |||
if samples[0].get('label', None) is not None: | |||
batch['labels'] = np.array([s['label'] for s in samples]).tolist() | |||
if samples[0].get('constraint_mask', None) is not None: | |||
batch['constraint_masks'] = merge('constraint_mask') | |||
if samples[0].get('decoder_prompt', None) is not None: | |||
batch['decoder_prompts'] = np.array( | |||
[s['decoder_prompt'].tolist() for s in samples]) | |||
if samples[0].get('prefix_token', None) is not None: | |||
batch['prefix_tokens'] = merge('prefix_token') | |||
# For detection and visual grounding | |||
if samples[0].get('w_resize_ratio', None) is not None: | |||
batch['w_resize_ratios'] = torch.stack( | |||
@@ -0,0 +1,13 @@ | |||
from modelscope.utils.constant import Tasks | |||
OFA_TASK_KEY_MAPPING = { | |||
Tasks.ocr_recognition: ['image'], | |||
Tasks.image_captioning: ['image'], | |||
Tasks.image_classification: ['image'], | |||
Tasks.text_summarization: ['text'], | |||
Tasks.text_classification: ['text', 'text2'], | |||
Tasks.visual_grounding: ['image', 'text'], | |||
Tasks.visual_question_answering: ['image', 'text'], | |||
Tasks.visual_entailment: ['image', 'text', 'text2'], | |||
Tasks.text_to_image_synthesis: ['text'] | |||
} |
@@ -6,24 +6,33 @@ from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir) | |||
super(OfaVisualEntailmentPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: image.convert('RGB'), | |||
transforms.Resize((self.patch_image_size, self.patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.Resize( | |||
(self.patch_image_size, self.patch_image_size), | |||
interpolation=transforms.InterpolationMode.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=self.mean, std=self.std), | |||
]) | |||
@@ -44,7 +53,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): | |||
prompt = self.cfg.model.get( | |||
'prompt', ' can image and text1 " {} " imply text2 " {} "?') | |||
text = prompt.format(caption, hypothesis) | |||
inputs = self.get_inputs(text) | |||
inputs = self.tokenize_text(text) | |||
if self.prompt_type == 'none': | |||
decoder_prompt = self.bos_item | |||
elif self.prompt_type == 'src': | |||
@@ -6,24 +6,33 @@ from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir) | |||
super(OfaVisualGroundingPreprocessor, | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: image.convert('RGB'), | |||
transforms.Resize((self.patch_image_size, self.patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.Resize( | |||
(self.patch_image_size, self.patch_image_size), | |||
interpolation=transforms.InterpolationMode.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=self.mean, std=self.std), | |||
]) | |||
@@ -39,7 +48,7 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): | |||
prompt = self.cfg.model.get( | |||
'prompt', ' which region does the text " {} " describe?') | |||
text = prompt.format(src_caption) | |||
src_item = self.get_inputs(text) | |||
src_item = self.tokenize_text(text) | |||
sample = { | |||
'source': src_item, | |||
'patch_image': patch_image, | |||
@@ -6,25 +6,33 @@ from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.preprocessors.image import load_image | |||
from modelscope.utils.constant import ModeKeys | |||
from .base import OfaBasePreprocessor | |||
class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): | |||
def __init__(self, cfg, model_dir): | |||
def __init__(self, | |||
cfg, | |||
model_dir, | |||
mode=ModeKeys.INFERENCE, | |||
*args, | |||
**kwargs): | |||
"""preprocess the data | |||
Args: | |||
cfg(modelscope.utils.config.ConfigDict) : model config | |||
model_dir (str): model path | |||
model_dir (str): model path, | |||
mode: preprocessor mode (model mode) | |||
""" | |||
super(OfaVisualQuestionAnsweringPreprocessor, | |||
self).__init__(cfg, model_dir) | |||
self).__init__(cfg, model_dir, mode, *args, **kwargs) | |||
# Initialize transform | |||
self.patch_resize_transform = transforms.Compose([ | |||
lambda image: image.convert('RGB'), | |||
transforms.Resize((self.patch_image_size, self.patch_image_size), | |||
interpolation=Image.BICUBIC), | |||
transforms.Resize( | |||
(self.patch_image_size, self.patch_image_size), | |||
interpolation=transforms.InterpolationMode.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=self.mean, std=self.std), | |||
]) | |||
@@ -34,7 +42,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): | |||
data['image'], Image.Image) else load_image(data['image']) | |||
patch_image = self.patch_resize_transform(image) | |||
text = ' {}'.format(data['text']) | |||
inputs = self.get_inputs(text) | |||
inputs = self.tokenize_text(text) | |||
if self.prompt_type == 'none': | |||
decoder_prompt = self.bos_item | |||
elif self.prompt_type == 'src': | |||
@@ -0,0 +1,3 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .ofa_trainer import OFATrainer |
@@ -0,0 +1,154 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
import shutil | |||
from functools import partial | |||
from typing import Callable, Dict, Optional, Tuple, Union | |||
import torch | |||
from torch import distributed as dist | |||
from torch import nn | |||
from torch.utils.data import Dataset | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.base import Model, TorchModel | |||
from modelscope.msdatasets.ms_dataset import MsDataset | |||
from modelscope.preprocessors.base import Preprocessor | |||
from modelscope.preprocessors.multi_modal import OfaPreprocessor | |||
from modelscope.preprocessors.ofa.utils.collate import collate_fn | |||
from modelscope.trainers import EpochBasedTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.optimizer.builder import build_optimizer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | |||
ModeKeys) | |||
from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | |||
get_schedule) | |||
@TRAINERS.register_module(module_name=Trainers.ofa) | |||
class OFATrainer(EpochBasedTrainer): | |||
def __init__( | |||
self, | |||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||
cfg_file: Optional[str] = None, | |||
arg_parse_fn: Optional[Callable] = None, | |||
data_collator: Optional[Union[Callable, Dict[str, | |||
Callable]]] = None, | |||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
preprocessor: Optional[Union[Preprocessor, | |||
Dict[str, Preprocessor]]] = None, | |||
optimizers: Tuple[torch.optim.Optimizer, | |||
torch.optim.lr_scheduler._LRScheduler] = (None, | |||
None), | |||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
seed: int = 42, | |||
**kwargs): | |||
model = Model.from_pretrained(model, revision=model_revision) | |||
model_dir = model.model_dir | |||
cfg = Config.from_file(cfg_file) | |||
if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: | |||
work_dir = cfg.train.work_dir | |||
else: | |||
work_dir = kwargs['work_dir'] | |||
tokenizer_files = { | |||
'zh': [ | |||
'tokenizer.json', 'tokenizer_config.json', 'vocab.txt', | |||
'config.json' | |||
], | |||
'en': | |||
['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'], | |||
} | |||
for filename in tokenizer_files[cfg.model.get('language', 'en')]: | |||
finetune_file = os.path.join(work_dir, filename) | |||
pretrain_file = os.path.join(model_dir, filename) | |||
if os.path.exists(finetune_file): | |||
continue | |||
if os.path.exists(pretrain_file): | |||
shutil.copy(pretrain_file, finetune_file) | |||
if preprocessor is None: | |||
preprocessor = { | |||
ConfigKeys.train: | |||
OfaPreprocessor( | |||
model_dir=work_dir, mode=ModeKeys.TRAIN, no_collate=True), | |||
ConfigKeys.val: | |||
OfaPreprocessor( | |||
model_dir=work_dir, mode=ModeKeys.EVAL, no_collate=True), | |||
} | |||
# use torchrun launch | |||
world_size = int(os.environ.get('WORLD_SIZE', 1)) | |||
epoch_steps = math.ceil( | |||
len(train_dataset) / # noqa | |||
(cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa | |||
cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs | |||
cfg.train.criterion.tokenizer = model.tokenizer | |||
self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( | |||
cfg.train.criterion) | |||
if optimizers[0] is None: | |||
optimizer = build_optimizer(model, cfg=cfg.train.optimizer) | |||
else: | |||
optimizer = optimizers[0] | |||
if optimizers[1] is None: | |||
scheduler_class, scheduler_args = get_schedule( | |||
cfg.train.lr_scheduler) | |||
if scheduler_class is not None: | |||
lr_scheduler = scheduler_class(**{'optimizer': optimizer}, | |||
**scheduler_args) | |||
else: | |||
lr_scheduler = None | |||
else: | |||
lr_scheduler = optimizers[1] | |||
optimizers = (optimizer, lr_scheduler) | |||
if data_collator is None: | |||
data_collator = partial( | |||
collate_fn, | |||
pad_idx=model.tokenizer.pad_token_id, | |||
eos_idx=model.tokenizer.eos_token_id, | |||
) | |||
if 'launcher' not in kwargs and cfg.train.get('launcher', None): | |||
kwargs['launcher'] = cfg.train.launcher | |||
if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): | |||
kwargs['use_fp16'] = cfg.train.use_fp16 | |||
kwargs['to_tensor'] = False | |||
super().__init__( | |||
model=model, | |||
cfg_file=cfg_file, | |||
arg_parse_fn=arg_parse_fn, | |||
data_collator=data_collator, | |||
train_dataset=train_dataset, | |||
eval_dataset=eval_dataset, | |||
preprocessor=preprocessor, | |||
optimizers=optimizers, | |||
seed=seed, | |||
**kwargs, | |||
) | |||
def train_step(self, model, inputs): | |||
model.train() | |||
model_outputs = model.forward(inputs) | |||
loss, sample_size, logging_output = self.criterion( | |||
model_outputs, inputs) | |||
train_outputs = {'loss': loss} | |||
# add model output info to log | |||
if 'log_vars' not in train_outputs: | |||
default_keys_pattern = ['loss'] | |||
match_keys = set([]) | |||
for key_p in default_keys_pattern: | |||
match_keys.update( | |||
[key for key in train_outputs.keys() if key_p in key]) | |||
log_vars = {} | |||
for key in match_keys: | |||
value = train_outputs.get(key, None) | |||
if value is not None: | |||
if dist.is_available() and dist.is_initialized(): | |||
value = value.data.clone() | |||
dist.all_reduce(value.div_(dist.get_world_size())) | |||
log_vars.update({key: value.item()}) | |||
self.log_buffer.update(log_vars) | |||
else: | |||
self.log_buffer.update(train_outputs['log_vars']) | |||
self.train_outputs = train_outputs |
@@ -0,0 +1,243 @@ | |||
# Copyright 2022 The OFA-Sys Team. | |||
# All rights reserved. | |||
# This source code is licensed under the Apache 2.0 license | |||
# found in the LICENSE file in the root directory. | |||
import math | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
import transformers | |||
from torch.nn.modules.loss import _Loss | |||
def construct_rdrop_sample(x): | |||
if isinstance(x, dict): | |||
for key in x: | |||
x[key] = construct_rdrop_sample(x[key]) | |||
return x | |||
elif isinstance(x, torch.Tensor): | |||
return x.repeat(2, *([1] * (x.dim() - 1))) | |||
elif isinstance(x, int): | |||
return x * 2 | |||
elif isinstance(x, np.ndarray): | |||
return x.repeat(2) | |||
else: | |||
raise NotImplementedError | |||
def kl_loss(p, q): | |||
p_loss = F.kl_div(p, torch.exp(q), reduction='sum') | |||
q_loss = F.kl_div(q, torch.exp(p), reduction='sum') | |||
loss = (p_loss + q_loss) / 2 | |||
return loss | |||
def label_smoothed_nll_loss(lprobs, | |||
target, | |||
epsilon, | |||
update_num, | |||
reduce=True, | |||
drop_worst_ratio=0.0, | |||
drop_worst_after=0, | |||
use_rdrop=False, | |||
reg_alpha=1.0, | |||
constraint_masks=None, | |||
constraint_start=None, | |||
constraint_end=None): | |||
if target.dim() == lprobs.dim() - 1: | |||
target = target.unsqueeze(-1) | |||
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1) | |||
if constraint_masks is not None: | |||
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum( | |||
dim=-1, keepdim=True).squeeze(-1) | |||
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6) | |||
elif constraint_start is not None and constraint_end is not None: | |||
constraint_range = [0, 1, 2, 3] + list( | |||
range(constraint_start, constraint_end)) | |||
smooth_loss = -lprobs[:, constraint_range].sum( | |||
dim=-1, keepdim=True).squeeze(-1) | |||
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6) | |||
else: | |||
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1) | |||
eps_i = epsilon / (lprobs.size(-1) - 1) | |||
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss | |||
if drop_worst_ratio > 0 and update_num > drop_worst_after: | |||
if use_rdrop: | |||
true_batch_size = loss.size(0) // 2 | |||
_, indices = torch.topk( | |||
loss[:true_batch_size], | |||
k=int(true_batch_size * (1 - drop_worst_ratio)), | |||
largest=False) | |||
loss = torch.cat([loss[indices], loss[indices + true_batch_size]]) | |||
nll_loss = torch.cat( | |||
[nll_loss[indices], nll_loss[indices + true_batch_size]]) | |||
lprobs = torch.cat( | |||
[lprobs[indices], lprobs[indices + true_batch_size]]) | |||
else: | |||
loss, indices = torch.topk( | |||
loss, | |||
k=int(loss.shape[0] * (1 - drop_worst_ratio)), | |||
largest=False) | |||
nll_loss = nll_loss[indices] | |||
lprobs = lprobs[indices] | |||
ntokens = loss.numel() | |||
nll_loss = nll_loss.sum() / ntokens # 后面在grads里面处理 | |||
loss = loss.sum() / ntokens # 后面在grads里面处理 | |||
if use_rdrop: | |||
true_batch_size = lprobs.size(0) // 2 | |||
p = lprobs[:true_batch_size] | |||
q = lprobs[true_batch_size:] | |||
if constraint_start is not None and constraint_end is not None: | |||
constraint_range = [0, 1, 2, 3] + list( | |||
range(constraint_start, constraint_end)) | |||
p = p[:, constraint_range] | |||
q = q[:, constraint_range] | |||
loss += kl_loss(p, q) * reg_alpha | |||
return loss, nll_loss, ntokens | |||
class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
def __init__(self, args): | |||
super().__init__() | |||
self.sentence_avg = args.sentence_avg | |||
self.eps = args.label_smoothing | |||
self.ignore_prefix_size = args.ignore_prefix_size | |||
self.ignore_eos = args.ignore_eos | |||
self.report_accuracy = args.report_accuracy | |||
self.drop_worst_ratio = args.drop_worst_ratio | |||
self.drop_worst_after = args.drop_worst_after | |||
self.use_rdrop = args.use_rdrop | |||
self.reg_alpha = args.reg_alpha | |||
self.sample_patch_num = args.sample_patch_num | |||
self.constraint_start = None | |||
self.constraint_end = None | |||
if args.constraint_range: | |||
constraint_start, constraint_end = args.constraint_range.split(',') | |||
self.constraint_start = int(constraint_start) | |||
self.constraint_end = int(constraint_end) | |||
self.padding_idx = args.tokenizer.pad_token_id | |||
self.args = args | |||
def forward(self, output, sample, update_num=0, reduce=True): | |||
"""Compute the loss for the given sample. | |||
Returns a tuple with three elements: | |||
1) the loss | |||
2) the sample size, which is used as the denominator for the gradient | |||
3) logging outputs to display while training | |||
""" | |||
if self.use_rdrop: | |||
construct_rdrop_sample(sample) | |||
loss, nll_loss, ntokens = self.compute_loss( | |||
output, sample, update_num, reduce=reduce) | |||
sample_size = ( | |||
sample['target'].size(0) if self.sentence_avg else ntokens) | |||
logging_output = { | |||
'loss': loss.data, | |||
'nll_loss': nll_loss.data, | |||
'ntokens': sample['ntokens'], | |||
'nsentences': sample['nsentences'], | |||
'sample_size': sample_size, | |||
} | |||
return loss, sample_size, logging_output | |||
def get_lprobs_and_target(self, net_output, sample): | |||
conf = sample['conf'][:, None, None] if 'conf' in sample and sample[ | |||
'conf'] is not None else 1 | |||
constraint_masks = None | |||
if 'constraint_masks' in sample and sample[ | |||
'constraint_masks'] is not None: | |||
constraint_masks = sample['constraint_masks'] | |||
net_output[0].masked_fill_(~constraint_masks, -math.inf) | |||
if self.constraint_start is not None and self.constraint_end is not None: | |||
net_output[0][:, :, 4:self.constraint_start] = -math.inf | |||
net_output[0][:, :, self.constraint_end:] = -math.inf | |||
lprobs = F.log_softmax( | |||
net_output[0], dim=-1, dtype=torch.float32) * conf | |||
target = sample['target'] | |||
if self.ignore_prefix_size > 0: | |||
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() | |||
target = target[:, self.ignore_prefix_size:].contiguous() | |||
if constraint_masks is not None: | |||
constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable | |||
if self.ignore_eos: | |||
bsz, seq_len, embed_dim = lprobs.size() | |||
eos_indices = target.eq(self.task.tgt_dict.eos()) | |||
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim) | |||
target = target[~eos_indices].reshape(bsz, seq_len - 1) | |||
if constraint_masks is not None: | |||
constraint_masks = constraint_masks[~eos_indices].reshape( | |||
bsz, seq_len - 1, embed_dim) | |||
if constraint_masks is not None: | |||
constraint_masks = constraint_masks.view(-1, | |||
constraint_masks.size(-1)) | |||
return lprobs.view(-1, | |||
lprobs.size(-1)), target.view(-1), constraint_masks | |||
def compute_loss(self, net_output, sample, update_num, reduce=True): | |||
lprobs, target, constraint_masks = self.get_lprobs_and_target( | |||
net_output, sample) | |||
if constraint_masks is not None: | |||
constraint_masks = constraint_masks[target != self.padding_idx] | |||
lprobs = lprobs[target != self.padding_idx] | |||
target = target[target != self.padding_idx] | |||
loss, nll_loss, ntokens = label_smoothed_nll_loss( | |||
lprobs, | |||
target, | |||
self.eps, | |||
update_num, | |||
reduce=reduce, | |||
drop_worst_ratio=self.drop_worst_ratio, | |||
drop_worst_after=self.drop_worst_after, | |||
use_rdrop=self.use_rdrop, | |||
reg_alpha=self.reg_alpha, | |||
constraint_masks=constraint_masks, | |||
constraint_start=self.constraint_start, | |||
constraint_end=self.constraint_end) | |||
return loss, nll_loss, ntokens | |||
def get_schedule(scheduler): | |||
if scheduler.name == 'const': | |||
scheduler_class = transformers.get_constant_schedule_with_warmup | |||
scheduler_args = { | |||
'num_warmup_steps': | |||
int(scheduler.warmup_proportion * scheduler.num_train_steps) | |||
} | |||
elif scheduler.name == 'linear': | |||
scheduler_class = transformers.get_linear_schedule_with_warmup | |||
scheduler_args = { | |||
'num_warmup_steps': | |||
int(scheduler.warmup_proportion * scheduler.num_train_steps), | |||
'num_training_steps': | |||
scheduler.num_train_steps | |||
} | |||
elif scheduler.name == 'cosine': | |||
scheduler_class = transformers.get_cosine_schedule_with_warmup | |||
scheduler_args = { | |||
'num_warmup_steps': | |||
int(scheduler.warmup_proportion * scheduler.num_train_steps), | |||
'num_training_steps': | |||
scheduler.num_train_steps | |||
} | |||
elif scheduler.name == 'polynomial_decay': | |||
scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup | |||
scheduler_args = { | |||
'num_warmup_steps': | |||
int(scheduler.warmup_proportion * scheduler.num_train_steps), | |||
'num_training_steps': | |||
scheduler.num_train_steps, | |||
'lr_end': | |||
scheduler.lr_end | |||
} | |||
else: | |||
raise NotImplementedError | |||
return scheduler_class, scheduler_args |
@@ -168,19 +168,20 @@ class EpochBasedTrainer(BaseTrainer): | |||
device_name = f'cuda:{local_rank}' | |||
self.device = create_device(device_name) | |||
self.train_dataset = self.to_task_dataset( | |||
train_dataset, | |||
mode=ModeKeys.TRAIN, | |||
task_data_config=self.cfg.dataset.get('train', None) if hasattr( | |||
self.cfg, 'dataset') else None, | |||
preprocessor=self.train_preprocessor) | |||
preprocessor=self.train_preprocessor, | |||
**kwargs) | |||
self.eval_dataset = self.to_task_dataset( | |||
eval_dataset, | |||
mode=ModeKeys.EVAL, | |||
task_data_config=self.cfg.dataset.get('val', None) if hasattr( | |||
self.cfg, 'dataset') else None, | |||
preprocessor=self.eval_preprocessor) | |||
preprocessor=self.eval_preprocessor, | |||
**kwargs) | |||
self.train_data_collator, self.eval_default_collate = None, None | |||
if isinstance(data_collator, Mapping): | |||
@@ -216,7 +217,6 @@ class EpochBasedTrainer(BaseTrainer): | |||
self._max_epochs = self.cfg.train.max_epochs | |||
else: | |||
self._max_epochs = kwargs['max_epochs'] | |||
self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) | |||
self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) | |||
if self._train_iters_per_epoch is None and hasattr( | |||
@@ -306,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer): | |||
datasets: Union[Dataset, List[Dataset]], | |||
mode: str, | |||
task_data_config: Config = None, | |||
preprocessor: Optional[Preprocessor] = None): | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
"""Build the task specific dataset processor for this trainer. | |||
Returns: The task dataset processor for the task. If no result for the very model-type and task, | |||
the default TaskDataset will be returned. | |||
""" | |||
try: | |||
to_tensor = kwargs.get('to_tensor', True) | |||
if not datasets: | |||
return datasets | |||
if isinstance(datasets, TorchTaskDataset): | |||
@@ -328,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
return datasets.to_torch_dataset( | |||
task_data_config=task_data_config, | |||
task_name=self.cfg.task, | |||
preprocessors=preprocessor) | |||
preprocessors=preprocessor, | |||
to_tensor=to_tensor) | |||
elif isinstance(datasets, List) and isinstance( | |||
datasets[0], MsDataset): | |||
if task_data_config is None: | |||
@@ -342,7 +345,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
d.to_torch_dataset( | |||
task_data_config=task_data_config, | |||
task_name=self.cfg.task, | |||
preprocessors=preprocessor) for d in datasets | |||
preprocessors=preprocessor, | |||
to_tensor=to_tensor) for d in datasets | |||
] | |||
cfg = ConfigDict( | |||
type=self.cfg.model.type, mode=mode, datasets=datasets) | |||
@@ -497,6 +501,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
dp_cfg = dict( | |||
type='DistributedDataParallel', | |||
module=model, | |||
find_unused_parameters=True, | |||
device_ids=[torch.cuda.current_device()]) | |||
return build_parallel(dp_cfg) | |||
@@ -779,7 +784,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
batch_size = batch_size_per_gpu | |||
num_workers = workers_per_gpu | |||
if dist: | |||
if dist and not isinstance(dataset, torch.utils.data.IterableDataset): | |||
sampler = DistributedSampler( | |||
dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) | |||
else: | |||
@@ -69,7 +69,10 @@ def single_gpu_test(model, | |||
batch_size = 1 # iteration count | |||
else: | |||
if isinstance(data, dict): | |||
batch_size = len(next(iter(data.values()))) | |||
if 'nsentences' in data: | |||
batch_size = data['nsentences'] | |||
else: | |||
batch_size = len(next(iter(data.values()))) | |||
else: | |||
batch_size = len(data) | |||
for _ in range(batch_size): | |||
@@ -152,21 +155,29 @@ def multi_gpu_test(model, | |||
result = model.forward(data) | |||
results.append(result) | |||
if rank == 0: | |||
if isinstance(data, dict): | |||
batch_size = len(next(iter(data.values()))) | |||
if isinstance(data, dict): | |||
if 'nsentences' in data: | |||
batch_size = data['nsentences'] | |||
else: | |||
batch_size = len(data) | |||
if progress_with_iters: | |||
total_samples += batch_size * world_size | |||
batch_size = 1 # iteration count | |||
batch_size = len(next(iter(data.values()))) | |||
else: | |||
batch_size = len(data) | |||
if i >= (data_len // world_size) - 1: | |||
total_samples = torch.LongTensor([batch_size]).to(model.device) | |||
dist.all_reduce(total_samples, op=dist.reduce_op.SUM) | |||
total_samples = total_samples.item() | |||
else: | |||
total_samples = batch_size * world_size | |||
if progress_with_iters: | |||
iter_cnt_all = world_size | |||
else: | |||
iter_cnt_all = total_samples | |||
count += iter_cnt_all | |||
batch_size_all = batch_size * world_size | |||
count += batch_size_all | |||
if rank == 0: | |||
if count > data_len: | |||
batch_size_all = data_len - (count - batch_size_all) | |||
for _ in range(batch_size_all): | |||
iter_cnt_all = data_len - (count - iter_cnt_all) | |||
for _ in range(iter_cnt_all): | |||
pbar.update() | |||
if progress_with_iters and (i + 1) >= data_len: | |||
@@ -280,6 +280,7 @@ class ConfigKeys(object): | |||
"""Fixed keywords in configuration file""" | |||
train = 'train' | |||
val = 'val' | |||
test = 'test' | |||
class Requirements(object): | |||
@@ -1,5 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from contextlib import contextmanager | |||
from modelscope.utils.constant import Devices, Frameworks | |||
@@ -106,3 +106,17 @@ def create_device(device_name): | |||
device = torch.device('cpu') | |||
return device | |||
def get_device(): | |||
import torch | |||
from torch import distributed as dist | |||
if torch.cuda.is_available(): | |||
if dist.is_available() and dist.is_initialized( | |||
) and 'LOCAL_RANK' in os.environ: | |||
device_id = f"cuda:{os.environ['LOCAL_RANK']}" | |||
else: | |||
device_id = 'cuda:0' | |||
else: | |||
device_id = 'cpu' | |||
return torch.device(device_id) |
@@ -0,0 +1,14 @@ | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from .fp16 import FP16_Module, FP16_Optimizer |
@@ -0,0 +1,655 @@ | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Stable version of apex FP16 Optimizer""" | |||
import torch | |||
from torch import nn | |||
from torch.autograd import Variable | |||
from torch.nn.parameter import Parameter | |||
from .fp16util import (master_params_to_model_params, | |||
model_grads_to_master_grads) | |||
from .loss_scaler import DynamicLossScaler, LossScaler | |||
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) | |||
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) | |||
def conversion_helper(val, conversion): | |||
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" | |||
if not isinstance(val, (tuple, list)): | |||
return conversion(val) | |||
rtn = [conversion_helper(v, conversion) for v in val] | |||
if isinstance(val, tuple): | |||
rtn = tuple(rtn) | |||
return rtn | |||
def fp32_to_fp16(val): | |||
"""Convert fp32 `val` to fp16""" | |||
def half_conversion(val): | |||
val_typecheck = val | |||
if isinstance(val_typecheck, (Parameter, Variable)): | |||
val_typecheck = val.data | |||
if isinstance(val_typecheck, FLOAT_TYPES): | |||
val = val.half() | |||
return val | |||
return conversion_helper(val, half_conversion) | |||
def fp16_to_fp32(val): | |||
"""Convert fp16 `val` to fp32""" | |||
def float_conversion(val): | |||
val_typecheck = val | |||
if isinstance(val_typecheck, (Parameter, Variable)): | |||
val_typecheck = val.data | |||
if isinstance(val_typecheck, HALF_TYPES): | |||
val = val.float() | |||
return val | |||
return conversion_helper(val, float_conversion) | |||
class FP16_Module(nn.Module): | |||
def __init__(self, module): | |||
super(FP16_Module, self).__init__() | |||
self.add_module('module', module.half()) | |||
def forward(self, *inputs, **kwargs): | |||
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) | |||
def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
return self.module.state_dict(destination, prefix, keep_vars) | |||
def load_state_dict(self, state_dict, strict=True): | |||
self.module.load_state_dict(state_dict, strict=strict) | |||
class FP16_Optimizer(object): | |||
""" | |||
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, | |||
and manage static or dynamic loss scaling and master weights in a manner transparent to the user. | |||
For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, | |||
and changing the call to ``backward``. | |||
Example:: | |||
model = torch.nn.Linear(D_in, D_out).cuda().half() | |||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||
# Name the FP16_Optimizer instance to replace the existing optimizer | |||
# (recommended but not required): | |||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||
... | |||
# loss.backward() becomes: | |||
optimizer.backward(loss) | |||
... | |||
Example with dynamic loss scaling:: | |||
... | |||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) | |||
# optional arg to control dynamic loss scaling behavior | |||
# dynamic_loss_args={'scale_window' : 500}) | |||
# Usually, dynamic_loss_args is not necessary. | |||
Args: | |||
init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa | |||
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa | |||
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa | |||
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa | |||
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa | |||
``init_optimizer`` is expected to have been constructed in the ordinary way. | |||
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be | |||
named to replace ``init_optimizer``, for two reasons: | |||
First, it means that references to the same name | |||
later in the file will not have to change. | |||
Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to | |||
modify ``init_optimizer``. If you do choose a unique name for the new | |||
:class:`FP16_Optimizer` instance, you should only work with this new instance, | |||
because the preexisting optimizer might no longer behave as expected. | |||
``init_optimizer`` may be any Pytorch optimizer. | |||
It may contain a mixture of fp16 and fp32 parameters organized into any number of | |||
``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will | |||
ingest these ``param_groups`` and remember them. | |||
Calls to :: | |||
loss.backward() | |||
must be replaced with :: | |||
optimizer.backward(loss) | |||
because :class:`FP16_Optimizer` requires ownership of the backward pass to implement | |||
loss scaling and copies to master gradients. | |||
.. note:: | |||
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients | |||
are downscaled before being applied. This means that adjusting the loss scale, or using | |||
dynamic loss scaling, should not require retuning the learning rate or any other | |||
hyperparameters. | |||
**Advanced options** | |||
**Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. | |||
See docstring for :attr:`step`. | |||
**Gradient clipping**: Use :attr:`clip_master_grads`. | |||
**Multiple losses**: If your model accumulates gradients from multiple losses, | |||
this can be made more efficient by supplying ``update_master_grads=False`` | |||
to :attr:`backward`. See docstring for :attr:`backward`. | |||
**Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: | |||
print(optimizer.loss_scale) | |||
optimizer.loss_scale = new_loss_scale | |||
For static loss scaling, manually adjusting the loss scale over time is a reasonable | |||
thing to do. During later epochs, gradients may become smaller, and a | |||
higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss | |||
scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting | |||
the loss scale is not recommended. | |||
**Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in | |||
Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` | |||
should still work as intended. | |||
""" | |||
def __init__(self, | |||
init_optimizer, | |||
static_loss_scale=1.0, | |||
dynamic_loss_scale=False, | |||
dynamic_loss_args=None, | |||
verbose=False): | |||
if not torch.cuda.is_available: | |||
raise SystemError('Cannot use fp16 without CUDA.') | |||
self.verbose = verbose | |||
self.optimizer = init_optimizer | |||
# init_state_dict sets up an alternative way to cast per-param state tensors. | |||
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. | |||
# init_state_dict = init_optimizer.state_dict() | |||
self.fp16_groups = [] | |||
self.fp32_from_fp16_groups = [] | |||
self.fp32_from_fp32_groups = [] | |||
for i, param_group in enumerate(self.optimizer.param_groups): | |||
self.maybe_print( | |||
'FP16_Optimizer processing param group {}:'.format(i)) | |||
fp16_params_this_group = [] | |||
fp32_params_this_group = [] | |||
fp32_from_fp16_params_this_group = [] | |||
for i, param in enumerate(param_group['params']): | |||
if param.requires_grad: | |||
if param.type() == 'torch.cuda.HalfTensor': | |||
self.maybe_print( | |||
'FP16_Optimizer received torch.cuda.HalfTensor with {}' | |||
.format(param.size())) | |||
fp16_params_this_group.append(param) | |||
master_param = param.detach().clone().float() | |||
master_param.requires_grad = True | |||
# Copythe model parallel flag. | |||
master_param.model_parallel = param.model_parallel | |||
param_group['params'][i] = master_param | |||
fp32_from_fp16_params_this_group.append(master_param) | |||
# Reset existing state dict key to the new master param. | |||
# We still need to recast per-param state tensors, if any, to FP32. | |||
if param in self.optimizer.state: | |||
self.optimizer.state[ | |||
master_param] = self.optimizer.state.pop(param) | |||
elif param.type() == 'torch.cuda.FloatTensor': | |||
self.maybe_print( | |||
'FP16_Optimizer received torch.cuda.FloatTensor with {}' | |||
.format(param.size())) | |||
fp32_params_this_group.append(param) | |||
param_group['params'][i] = param | |||
else: | |||
raise TypeError( | |||
'Wrapped parameters must be either ' | |||
'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' | |||
'Received {}'.format(param.type())) | |||
self.fp16_groups.append(fp16_params_this_group) | |||
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) | |||
self.fp32_from_fp32_groups.append(fp32_params_this_group) | |||
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors | |||
self.optimizer.load_state_dict(self.optimizer.state_dict()) | |||
# alternative way to cast per-param state tensors: | |||
# self.optimizer.load_state_dict(init_state_dict) | |||
if dynamic_loss_scale: | |||
self.dynamic_loss_scale = True | |||
if dynamic_loss_args is not None: | |||
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) | |||
else: | |||
self.loss_scaler = DynamicLossScaler() | |||
else: | |||
self.dynamic_loss_scale = False | |||
self.loss_scaler = LossScaler(static_loss_scale) | |||
self.overflow = False | |||
self.first_closure_call_this_step = True | |||
self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_ | |||
def maybe_print(self, msg): | |||
if self.verbose: | |||
print(msg) | |||
def __getstate__(self): | |||
raise RuntimeError( | |||
'FP16_Optimizer should be serialized using state_dict().') | |||
def __setstate__(self, state): | |||
raise RuntimeError( | |||
'FP16_Optimizer should be deserialized using load_state_dict().') | |||
def zero_grad(self, set_grads_to_None=False): | |||
""" | |||
Zero fp32 and fp16 parameter grads. | |||
""" | |||
# In principle, only the .grad attributes of the model params need to be zeroed, | |||
# because gradients are copied into the FP32 master params. However, we zero | |||
# all gradients owned by the optimizer, just to be safe: | |||
for group in self.optimizer.param_groups: | |||
for p in group['params']: | |||
if set_grads_to_None: | |||
p.grad = None | |||
else: | |||
if p.grad is not None: | |||
p.grad.detach_() | |||
p.grad.zero_() | |||
# Zero fp16 gradients owned by the model: | |||
for fp16_group in self.fp16_groups: | |||
for param in fp16_group: | |||
if set_grads_to_None: | |||
param.grad = None | |||
else: | |||
if param.grad is not None: | |||
param.grad.detach_( | |||
) # as in torch.optim.optimizer.zero_grad() | |||
param.grad.zero_() | |||
def _check_overflow(self): | |||
params = [] | |||
for group in self.fp16_groups: | |||
for param in group: | |||
params.append(param) | |||
for group in self.fp32_from_fp32_groups: | |||
for param in group: | |||
params.append(param) | |||
self.overflow = self.loss_scaler.has_overflow(params) | |||
def _update_scale(self, has_overflow=False): | |||
self.loss_scaler.update_scale(has_overflow) | |||
def _master_params_to_model_params(self): | |||
for fp16_group, fp32_from_fp16_group in zip( | |||
self.fp16_groups, self.fp32_from_fp16_groups): | |||
master_params_to_model_params(fp16_group, fp32_from_fp16_group) | |||
def _model_params_to_master_params(self): | |||
for fp16_group, fp32_from_fp16_group in zip( | |||
self.fp16_groups, self.fp32_from_fp16_groups): | |||
master_params_to_model_params(fp32_from_fp16_group, fp16_group) | |||
# To consider: Integrate distributed with this wrapper by registering a hook on each variable | |||
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. | |||
def _model_grads_to_master_grads(self): | |||
for fp16_group, fp32_from_fp16_group in zip( | |||
self.fp16_groups, self.fp32_from_fp16_groups): | |||
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) | |||
def _downscale_master(self): | |||
if self.loss_scale != 1.0: | |||
for group in self.optimizer.param_groups: | |||
for param in group['params']: | |||
if param.grad is not None: | |||
param.grad.data.mul_(1. / self.loss_scale) | |||
def clip_master_grads(self, max_norm, norm_type=2): | |||
""" | |||
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. | |||
Args: | |||
max_norm (float or int): max norm of the gradients | |||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||
infinity norm. | |||
Returns: | |||
Total norm of the current fp32 gradients (viewed as a single vector). | |||
.. warning:: | |||
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa | |||
""" | |||
if not self.overflow: | |||
fp32_params = [] | |||
for param_group in self.optimizer.param_groups: | |||
for param in param_group['params']: | |||
fp32_params.append(param) | |||
return self.clip_grad_norm(fp32_params, max_norm, norm_type) | |||
else: | |||
return -1 | |||
def state_dict(self): | |||
""" | |||
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. | |||
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict | |||
of the contained Pytorch optimizer. | |||
Example:: | |||
checkpoint = {} | |||
checkpoint['model'] = model.state_dict() | |||
checkpoint['optimizer'] = optimizer.state_dict() | |||
torch.save(checkpoint, "saved.pth") | |||
""" | |||
state_dict = {} | |||
state_dict['loss_scaler'] = self.loss_scaler | |||
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale | |||
state_dict['overflow'] = self.overflow | |||
state_dict[ | |||
'first_closure_call_this_step'] = self.first_closure_call_this_step | |||
state_dict['optimizer_state_dict'] = self.optimizer.state_dict() | |||
state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups | |||
return state_dict | |||
def load_state_dict(self, state_dict): | |||
""" | |||
Loads a state_dict created by an earlier call to state_dict(). | |||
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, | |||
whose parameters in turn came from ``model``, it is expected that the user | |||
will call ``model.load_state_dict()`` before | |||
``fp16_optimizer_instance.load_state_dict()`` is called. | |||
Example:: | |||
model = torch.nn.Linear(D_in, D_out).cuda().half() | |||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||
... | |||
checkpoint = torch.load("saved.pth") | |||
model.load_state_dict(checkpoint['model']) | |||
optimizer.load_state_dict(checkpoint['optimizer']) | |||
""" | |||
# I think it should actually be ok to reload the optimizer before the model. | |||
self.loss_scaler = state_dict['loss_scaler'] | |||
self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] | |||
self.overflow = state_dict['overflow'] | |||
self.first_closure_call_this_step = state_dict[ | |||
'first_closure_call_this_step'] | |||
self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) | |||
# At this point, the optimizer's references to the model's fp32 parameters are up to date. | |||
# The optimizer's hyperparameters and internal buffers are also up to date. | |||
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still | |||
# out of date. There are two options. | |||
# 1: Refresh the master params from the model's fp16 params. | |||
# This requires less storage but incurs precision loss. | |||
# 2: Save and restore the fp32 master copies separately. | |||
# We choose option 2. | |||
# | |||
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device | |||
# of their associated parameters, because it's possible those buffers might not exist yet in | |||
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been | |||
# constructed in the same way as the one whose state_dict we are loading, the same master params | |||
# are guaranteed to exist, so we can just copy_() from the saved master params. | |||
for current_group, saved_group in zip(self.fp32_from_fp16_groups, | |||
state_dict['fp32_from_fp16']): | |||
for current, saved in zip(current_group, saved_group): | |||
current.data.copy_(saved.data) | |||
def step(self, closure=None): # could add clip option. | |||
""" | |||
If no closure is supplied, :attr:`step` should be called after | |||
``fp16_optimizer_obj.backward(loss)``. | |||
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to | |||
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params | |||
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run | |||
another forward pass using their model. | |||
If a closure is supplied, :attr:`step` may be called without a prior call to | |||
:attr:`backward(loss)`. | |||
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. | |||
However, the user should take care that any ``loss.backward()`` call within the closure | |||
has been replaced by ``fp16_optimizer_obj.backward(loss)``. | |||
Args: | |||
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa | |||
Example with closure:: | |||
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an | |||
# existing pytorch optimizer. | |||
for input, target in dataset: | |||
def closure(): | |||
optimizer.zero_grad() | |||
output = model(input) | |||
loss = loss_fn(output, target) | |||
# loss.backward() becomes: | |||
optimizer.backward(loss) | |||
return loss | |||
optimizer.step(closure) | |||
.. warning:: | |||
Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. | |||
.. _`ordinary Pytorch optimizer use`: | |||
http://pytorch.org/docs/master/optim.html#optimizer-step-closure | |||
""" | |||
scale = self.loss_scaler.loss_scale | |||
self._update_scale(self.overflow) | |||
if self.overflow: | |||
self.maybe_print( | |||
'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' | |||
.format(scale, self.loss_scale)) | |||
return | |||
if closure is not None: | |||
retval = self._step_with_closure(closure) | |||
else: | |||
retval = self.optimizer.step() | |||
self._master_params_to_model_params() | |||
return retval | |||
def _step_with_closure(self, closure): | |||
def wrapped_closure(): | |||
# helpful for debugging | |||
# print("Calling wrapped_closure, first_closure_call_this_step = {}" | |||
# .format(self.first_closure_call_this_step)) | |||
if self.first_closure_call_this_step: | |||
# We expect that the fp16 params are initially fresh on entering self.step(), | |||
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure() | |||
# is called within self.optimizer.step(). | |||
self.first_closure_call_this_step = False | |||
else: | |||
# If self.optimizer.step() internally calls wrapped_closure more than once, | |||
# it may update the fp32 params after each call. However, self.optimizer | |||
# doesn't know about the fp16 params at all. If the fp32 params get updated, | |||
# we can't rely on self.optimizer to refresh the fp16 params. We need | |||
# to handle that manually: | |||
self._master_params_to_model_params() | |||
# Our API expects the user to give us ownership of the backward() call by | |||
# replacing all calls to loss.backward() with optimizer.backward(loss). | |||
# This requirement holds whether or not the call to backward() is made within a closure. | |||
# If the user is properly calling optimizer.backward(loss) within "closure," | |||
# calling closure() here will give the fp32 master params fresh gradients | |||
# for the optimizer to play with, so all wrapped_closure needs to do is call | |||
# closure() and return the loss. | |||
temp_loss = closure() | |||
while (self.overflow): | |||
scale = self.loss_scaler.loss_scale | |||
self._update_scale(self.overflow) | |||
self.maybe_print( | |||
'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' | |||
'reducing to {}'.format(scale, self.loss_scale)) | |||
temp_loss = closure() | |||
return temp_loss | |||
retval = self.optimizer.step(wrapped_closure) | |||
self.first_closure_call_this_step = True | |||
return retval | |||
def backward(self, loss, update_master_grads=True, retain_graph=False): | |||
""" | |||
:attr:`backward` performs the following conceptual steps: | |||
1. fp32_loss = loss.float() (see first Note below) | |||
2. scaled_loss = fp32_loss*loss_scale | |||
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa | |||
4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa | |||
5. Finally, master grads are divided by loss_scale. | |||
In this way, after :attr:`backward`, the master params have fresh gradients, | |||
and :attr:`step` may be called. | |||
.. note:: | |||
:attr:`backward` internally converts the loss to fp32 before applying the loss scale. | |||
This provides some additional safety against overflow if the user has supplied an | |||
fp16 loss value. | |||
However, for maximum overflow safety, the user should | |||
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to | |||
:attr:`backward`. | |||
.. warning:: | |||
The gradients found in a model's leaves after the call to | |||
:attr:`backward` should not be regarded as valid in general, | |||
because it's possible | |||
they have been scaled (and in the case of dynamic loss scaling, | |||
the scale factor may change over time). | |||
If the user wants to inspect gradients after a call to :attr:`backward`, | |||
only the master gradients should be regarded as valid. These can be retrieved via | |||
:attr:`inspect_master_grad_data()`. | |||
Args: | |||
loss: The loss output by the user's model. loss may be either float or half (but see first Note above). | |||
update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa | |||
retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa | |||
Example:: | |||
# Ordinary operation: | |||
optimizer.backward(loss) | |||
# Naive operation with multiple losses (technically valid, but less efficient): | |||
# fp32 grads will be correct after the second call, but | |||
# the first call incurs an unnecessary fp16->fp32 grad copy. | |||
optimizer.backward(loss1) | |||
optimizer.backward(loss2) | |||
# More efficient way to handle multiple losses: | |||
# The fp16->fp32 grad copy is delayed until fp16 grads from all | |||
# losses have been accumulated. | |||
optimizer.backward(loss1, update_master_grads=False) | |||
optimizer.backward(loss2, update_master_grads=False) | |||
optimizer.update_master_grads() | |||
""" | |||
# To consider: try multiple backward passes using retain_grad=True to find | |||
# a loss scale that works. After you find a loss scale that works, do a final dummy | |||
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid | |||
# discarding the iteration, but probably wouldn't improve overall efficiency. | |||
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) | |||
if update_master_grads: | |||
self.update_master_grads() | |||
def update_master_grads(self): | |||
""" | |||
Copy the ``.grad`` attribute from stored references to fp16 parameters to | |||
the ``.grad`` attribute of the fp32 master parameters that are directly | |||
updated by the optimizer. :attr:`update_master_grads` only needs to be called if | |||
``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. | |||
""" | |||
if self.dynamic_loss_scale: | |||
self._check_overflow() | |||
if self.overflow: return # noqa | |||
self._model_grads_to_master_grads() | |||
self._downscale_master() | |||
def inspect_master_grad_data(self): | |||
""" | |||
When running with :class:`FP16_Optimizer`, | |||
``.grad`` attributes of a model's fp16 leaves should not be | |||
regarded as truthful, because they might be scaled. | |||
After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, | |||
the fp32 master params' ``.grad`` | |||
attributes will contain valid gradients properly divided by the loss scale. However, | |||
because :class:`FP16_Optimizer` flattens some parameters, accessing them may be | |||
nonintuitive. :attr:`inspect_master_grad_data` | |||
allows those gradients to be viewed with shapes corresponding to their associated model leaves. | |||
Returns: | |||
List of lists (one list for each parameter group). The list for each parameter group | |||
is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. | |||
""" | |||
if self.overflow: | |||
print( | |||
'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' | |||
'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' | |||
) | |||
return None | |||
else: | |||
# The optimizer owns only references to master params. | |||
master_grads_data = [] | |||
for param_group in self.optimizer.param_groups: | |||
master_grads_this_group = [] | |||
for param in param_group['params']: | |||
if param.grad is not None: | |||
master_grads_this_group.append(param.grad.data) | |||
else: | |||
master_grads_this_group.append(None) | |||
master_grads_data.append(master_grads_this_group) | |||
return master_grads_data | |||
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" | |||
def _get_loss_scale(self): | |||
return self.loss_scaler.loss_scale | |||
def _set_loss_scale(self, value): | |||
self.loss_scaler.cur_scale = value | |||
loss_scale = property(_get_loss_scale, _set_loss_scale) | |||
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" | |||
def _get_state(self): | |||
return self.optimizer.state | |||
def _set_state(self, value): | |||
self.optimizer.state = value | |||
state = property(_get_state, _set_state) | |||
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" | |||
# (for example, to adjust the learning rate) | |||
def _get_param_groups(self): | |||
return self.optimizer.param_groups | |||
def _set_param_groups(self, value): | |||
self.optimizer.param_groups = value | |||
param_groups = property(_get_param_groups, _set_param_groups) |
@@ -0,0 +1,216 @@ | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
from torch.autograd import Variable | |||
class tofp16(nn.Module): | |||
""" | |||
Utility module that implements:: | |||
def forward(self, input): | |||
return input.half() | |||
""" | |||
def __init__(self): | |||
super(tofp16, self).__init__() | |||
def forward(self, input): | |||
return input.half() | |||
def BN_convert_float(module): | |||
""" | |||
Utility function for network_to_half(). | |||
Retained for legacy purposes. | |||
""" | |||
if isinstance( | |||
module, | |||
torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: | |||
module.float() | |||
for child in module.children(): | |||
BN_convert_float(child) | |||
return module | |||
def network_to_half(network): | |||
""" | |||
Convert model to half precision in a batchnorm-safe way. | |||
Retained for legacy purposes. It is recommended to use FP16Model. | |||
""" | |||
return nn.Sequential(tofp16(), BN_convert_float(network.half())) | |||
def convert_module(module, dtype): | |||
""" | |||
Converts a module's immediate parameters and buffers to dtype. | |||
""" | |||
for param in module.parameters(recurse=False): | |||
if param is not None: | |||
if param.data.dtype.is_floating_point: | |||
param.data = param.data.to(dtype=dtype) | |||
if param._grad is not None and param._grad.data.dtype.is_floating_point: | |||
param._grad.data = param._grad.data.to(dtype=dtype) | |||
for buf in module.buffers(recurse=False): | |||
if buf is not None and buf.data.dtype.is_floating_point: | |||
buf.data = buf.data.to(dtype=dtype) | |||
def convert_network(network, dtype): | |||
""" | |||
Converts a network's parameters and buffers to dtype. | |||
""" | |||
for module in network.modules(): | |||
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm | |||
) and module.affine is True: | |||
continue | |||
convert_module(module, dtype) | |||
return network | |||
class FP16Model(nn.Module): | |||
""" | |||
Convert model to half precision in a batchnorm-safe way. | |||
""" | |||
def __init__(self, network): | |||
super(FP16Model, self).__init__() | |||
self.network = convert_network(network, dtype=torch.half) | |||
def forward(self, *inputs): | |||
inputs = tuple(t.half() for t in inputs) | |||
return self.network(*inputs) | |||
def backwards_debug_hook(grad): | |||
raise RuntimeError( | |||
'master_params recieved a gradient in the backward pass!') | |||
def prep_param_lists(model, flat_master=False): | |||
""" | |||
Creates a list of FP32 master parameters for a given model, as in | |||
`Training Neural Networks with Mixed Precision: Real Examples`_. | |||
Args: | |||
model (torch.nn.Module): Existing Pytorch model | |||
flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa | |||
Returns: | |||
A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa | |||
Example:: | |||
model_params, master_params = prep_param_lists(model) | |||
.. warning:: | |||
Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa | |||
.. _`Training Neural Networks with Mixed Precision: Real Examples`: | |||
http://on-demand.gputechconf.com/gtc/2018/video/S81012/ | |||
""" | |||
model_params = [ | |||
param for param in model.parameters() if param.requires_grad | |||
] | |||
if flat_master: | |||
# Give the user some more useful error messages | |||
try: | |||
# flatten_dense_tensors returns a contiguous flat array. | |||
# http://pytorch.org/docs/master/_modules/torch/_utils.html | |||
master_params = _flatten_dense_tensors( | |||
[param.data for param in model_params]).float() | |||
except: # noqa | |||
print( | |||
'Error in prep_param_lists: model may contain a mixture of parameters ' | |||
'of different types. Use flat_master=False, or use F16_Optimizer.' | |||
) | |||
raise | |||
master_params = torch.nn.Parameter(master_params) | |||
master_params.requires_grad = True | |||
# master_params.register_hook(backwards_debug_hook) | |||
if master_params.grad is None: | |||
master_params.grad = master_params.new(*master_params.size()) | |||
return model_params, [master_params] | |||
else: | |||
master_params = [ | |||
param.clone().float().detach() for param in model_params | |||
] | |||
for param in master_params: | |||
param.requires_grad = True | |||
return model_params, master_params | |||
def model_grads_to_master_grads(model_params, | |||
master_params, | |||
flat_master=False): | |||
""" | |||
Copy model gradients to master gradients. | |||
Args: | |||
model_params: List of model parameters created by :func:`prep_param_lists`. | |||
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa | |||
""" | |||
if flat_master: | |||
# The flattening may incur one more deep copy than is necessary. | |||
master_params[0].grad.data.copy_( | |||
_flatten_dense_tensors([p.grad.data for p in model_params])) | |||
else: | |||
for model, master in zip(model_params, master_params): | |||
if model.grad is not None: | |||
if master.grad is None: | |||
master.grad = Variable( | |||
master.data.new(*master.data.size())) | |||
master.grad.data.copy_(model.grad.data) | |||
else: | |||
master.grad = None | |||
def master_params_to_model_params(model_params, | |||
master_params, | |||
flat_master=False): | |||
""" | |||
Copy master parameters to model parameters. | |||
Args: | |||
model_params: List of model parameters created by :func:`prep_param_lists`. | |||
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa | |||
""" | |||
if flat_master: | |||
for model, master in zip( | |||
model_params, | |||
_unflatten_dense_tensors(master_params[0].data, model_params)): | |||
model.data.copy_(master) | |||
else: | |||
for model, master in zip(model_params, master_params): | |||
model.data.copy_(master.data) | |||
# Backward compatibility fixes | |||
def to_python_float(t): | |||
if hasattr(t, 'item'): | |||
return t.item() | |||
else: | |||
return t[0] | |||
TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||
TORCH_MINOR = int(torch.__version__.split('.')[1]) |
@@ -0,0 +1,237 @@ | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
# item() is a recent addition, so this helps with backward compatibility. | |||
def to_python_float(t): | |||
if hasattr(t, 'item'): | |||
return t.item() | |||
else: | |||
return t[0] | |||
class LossScaler: | |||
""" | |||
Class that manages a static loss scale. This class is intended to interact with | |||
:class:`FP16_Optimizer`, and should not be directly manipulated by the user. | |||
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to | |||
:class:`FP16_Optimizer`'s constructor. | |||
Args: | |||
scale (float, optional, default=1.0): The loss scale. | |||
""" | |||
def __init__(self, scale=1): | |||
self.cur_scale = scale | |||
# `params` is a list / generator of torch.Variable | |||
def has_overflow(self, params): | |||
return False | |||
# `x` is a torch.Tensor | |||
def _has_inf_or_nan(x): | |||
return False | |||
def update_scale(self, overflow): | |||
pass | |||
@property | |||
def loss_scale(self): | |||
return self.cur_scale | |||
def scale_gradient(self, module, grad_in, grad_out): | |||
return tuple(self.loss_scale * g for g in grad_in) | |||
def backward(self, loss, retain_graph=False): | |||
scaled_loss = loss * self.loss_scale | |||
scaled_loss.backward(retain_graph=retain_graph) | |||
class DynamicLossScaler: | |||
""" | |||
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` | |||
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of | |||
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` | |||
operates, because the default options can be changed using the | |||
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. | |||
Loss scaling is designed to combat the problem of underflowing gradients encountered at long | |||
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss | |||
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are | |||
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has | |||
occurred. | |||
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, | |||
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. | |||
If a certain number of iterations occur without overflowing gradients detected, | |||
:class:`DynamicLossScaler` increases the loss scale once more. | |||
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of | |||
always using the highest loss scale possible without incurring overflow. | |||
Args: | |||
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` | |||
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa | |||
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. # noqa | |||
""" | |||
def __init__(self, | |||
init_scale=2**32, | |||
scale_factor=2., | |||
scale_window=1000, | |||
min_scale=1, | |||
delayed_shift=1, | |||
consecutive_hysteresis=False): | |||
self.cur_scale = init_scale | |||
self.cur_iter = 0 | |||
self.last_overflow_iter = -1 | |||
self.scale_factor = scale_factor | |||
self.scale_window = scale_window | |||
self.min_scale = min_scale | |||
self.delayed_shift = delayed_shift | |||
self.cur_hysteresis = delayed_shift | |||
self.consecutive_hysteresis = consecutive_hysteresis | |||
# `params` is a list / generator of torch.Variable | |||
def has_overflow_serial(self, params): | |||
for p in params: | |||
if p.grad is not None and DynamicLossScaler._has_inf_or_nan( | |||
p.grad.data): | |||
return True | |||
return False | |||
def has_overflow(self, params): | |||
overflow = self.has_overflow_serial(params) | |||
overflow_gpu = torch.cuda.ByteTensor([overflow]) | |||
overflow = overflow_gpu[0].item() | |||
return bool(overflow) | |||
# `x` is a torch.Tensor | |||
def _has_inf_or_nan(x): | |||
try: | |||
# if x is half, the .float() incurs an additional deep copy, but it's necessary if | |||
# Pytorch's .sum() creates a one-element tensor of the same type as x | |||
# (which is true for some recent version of pytorch). | |||
cpu_sum = float(x.float().sum()) | |||
# More efficient version that can be used if .sum() returns a Python scalar | |||
# cpu_sum = float(x.sum()) | |||
except RuntimeError as instance: | |||
# We want to check if inst is actually an overflow exception. | |||
# RuntimeError could come from a different error. | |||
# If so, we still want the exception to propagate. | |||
if 'value cannot be converted' not in instance.args[0]: | |||
raise | |||
return True | |||
else: | |||
if cpu_sum == float( | |||
'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: | |||
return True | |||
return False | |||
# `overflow` is boolean indicating whether the gradient overflowed | |||
def update_scale(self, overflow): | |||
if not hasattr(self, 'min_scale'): | |||
self.min_scale = 1 | |||
if not hasattr(self, 'delayed_shift'): | |||
self.delayed_shift = 1 | |||
if not hasattr(self, 'cur_hysteresis'): | |||
self.cur_hysteresis = 1 | |||
if not hasattr(self, 'consecutive_hysteresis'): | |||
self.consecutive_hysteresis = True | |||
if overflow: | |||
# self.cur_scale /= self.scale_factor | |||
if self.delayed_shift == 1 or self.cur_hysteresis == 1: | |||
self.cur_scale = max(self.cur_scale / self.scale_factor, | |||
self.min_scale) | |||
else: | |||
self.cur_hysteresis -= 1 | |||
self.last_overflow_iter = self.cur_iter | |||
else: | |||
if self.consecutive_hysteresis: | |||
self.cur_hysteresis = self.delayed_shift | |||
if (self.cur_iter | |||
- self.last_overflow_iter) % self.scale_window == 0: | |||
if not self.consecutive_hysteresis: | |||
self.cur_hysteresis = self.delayed_shift | |||
self.cur_scale *= self.scale_factor | |||
self.cur_iter += 1 | |||
@property | |||
def loss_scale(self): | |||
return self.cur_scale | |||
def scale_gradient(self, module, grad_in, grad_out): | |||
return tuple(self.loss_scale * g for g in grad_in) | |||
def backward(self, loss, retain_graph=False): | |||
scaled_loss = loss * self.loss_scale | |||
scaled_loss.backward(retain_graph=retain_graph) | |||
############################################################## | |||
# Example usage below here -- assuming it's in a separate file | |||
############################################################## | |||
""" | |||
TO-DO separate out into an example. | |||
if __name__ == "__main__": | |||
import torch | |||
from torch.autograd import Variable | |||
from dynamic_loss_scaler import DynamicLossScaler | |||
# N is batch size; D_in is input dimension; | |||
# H is hidden dimension; D_out is output dimension. | |||
N, D_in, H, D_out = 64, 1000, 100, 10 | |||
# Create random Tensors to hold inputs and outputs, and wrap them in Variables. | |||
x = Variable(torch.randn(N, D_in), requires_grad=False) | |||
y = Variable(torch.randn(N, D_out), requires_grad=False) | |||
w1 = Variable(torch.randn(D_in, H), requires_grad=True) | |||
w2 = Variable(torch.randn(H, D_out), requires_grad=True) | |||
parameters = [w1, w2] | |||
learning_rate = 1e-6 | |||
optimizer = torch.optim.SGD(parameters, lr=learning_rate) | |||
loss_scaler = DynamicLossScaler() | |||
for t in range(500): | |||
y_pred = x.mm(w1).clamp(min=0).mm(w2) | |||
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale | |||
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) | |||
print('Iter {} scaled loss: {}'.format(t, loss.data[0])) | |||
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) | |||
# Run backprop | |||
optimizer.zero_grad() | |||
loss.backward() | |||
# Check for overflow | |||
has_overflow = DynamicLossScaler.has_overflow(parameters) | |||
# If no overflow, unscale grad and update as usual | |||
if not has_overflow: | |||
for param in parameters: | |||
param.grad.data.mul_(1. / loss_scaler.loss_scale) | |||
optimizer.step() | |||
# Otherwise, don't do anything -- ie, skip iteration | |||
else: | |||
print('OVERFLOW!') | |||
# Update loss scale for next iteration | |||
loss_scaler.update_scale(has_overflow) | |||
""" |
@@ -5,6 +5,7 @@ pycocotools>=2.0.4 | |||
# rough-score was just recently updated from 0.0.4 to 0.0.7 | |||
# which introduced compatability issues that are being investigated | |||
rouge_score<=0.0.4 | |||
sacrebleu | |||
taming-transformers-rom1504 | |||
timm | |||
tokenizers | |||
@@ -0,0 +1,105 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import unittest | |||
import json | |||
from modelscope.metainfo import Metrics, Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.test_utils import test_level | |||
class TestOfaTrainer(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.finetune_cfg = \ | |||
{'framework': 'pytorch', | |||
'task': 'image-captioning', | |||
'model': {'type': 'ofa', | |||
'beam_search': {'beam_size': 5, | |||
'max_len_b': 16, | |||
'min_len': 1, | |||
'no_repeat_ngram_size': 0}, | |||
'seed': 7, | |||
'max_src_length': 256, | |||
'language': 'en', | |||
'gen_type': 'generation', | |||
'patch_image_size': 480, | |||
'max_image_size': 480, | |||
'imagenet_default_mean_and_std': False}, | |||
'pipeline': {'type': 'image-captioning'}, | |||
'dataset': {'column_map': {'text': 'caption'}}, | |||
'train': {'work_dir': 'work/ckpts/caption', | |||
# 'launcher': 'pytorch', | |||
'max_epochs': 1, | |||
'use_fp16': True, | |||
'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0}, | |||
'lr_scheduler': {'name': 'polynomial_decay', | |||
'warmup_proportion': 0.01, | |||
'lr_end': 1e-07}, | |||
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, | |||
'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01}, | |||
'optimizer_hook': {'type': 'TorchAMPOptimizerHook', | |||
'cumulative_iters': 1, | |||
'grad_clip': {'max_norm': 1.0, 'norm_type': 2}, | |||
'loss_keys': 'loss'}, | |||
'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion', | |||
'constraint_range': None, | |||
'drop_worst_after': 0, | |||
'drop_worst_ratio': 0.0, | |||
'ignore_eos': False, | |||
'ignore_prefix_size': 0, | |||
'label_smoothing': 0.1, | |||
'reg_alpha': 1.0, | |||
'report_accuracy': False, | |||
'sample_patch_num': 196, | |||
'sentence_avg': False, | |||
'use_rdrop': False}, | |||
'hooks': [{'type': 'BestCkptSaverHook', | |||
'metric_key': 'bleu-4', | |||
'interval': 100}, | |||
{'type': 'TextLoggerHook', 'interval': 1}, | |||
{'type': 'IterTimerHook'}, | |||
{'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]}, | |||
'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0}, | |||
'metrics': [{'type': 'bleu', | |||
'eval_tokenized_bleu': False, | |||
'ref_name': 'labels', | |||
'hyp_name': 'caption'}]}, | |||
'preprocessor': []} | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_trainer_std(self): | |||
WORKSPACE = './workspace/ckpts/caption' | |||
os.makedirs(WORKSPACE, exist_ok=True) | |||
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) | |||
with open(config_file, 'w') as writer: | |||
json.dump(self.finetune_cfg, writer) | |||
pretrained_model = 'damo/ofa_image-caption_coco_distilled_en' | |||
args = dict( | |||
model=pretrained_model, | |||
work_dir=WORKSPACE, | |||
train_dataset=MsDataset.load( | |||
'coco_2014_caption', | |||
namespace='modelscope', | |||
split='train[:20]'), | |||
eval_dataset=MsDataset.load( | |||
'coco_2014_caption', | |||
namespace='modelscope', | |||
split='validation[:10]'), | |||
metrics=[Metrics.BLEU], | |||
cfg_file=config_file) | |||
trainer = build_trainer(name=Trainers.ofa, default_args=args) | |||
trainer.train() | |||
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | |||
os.listdir(os.path.join(WORKSPACE, 'output'))) | |||
shutil.rmtree(WORKSPACE) | |||
if __name__ == '__main__': | |||
unittest.main() |