1. Fix bugs in daily test 2. Fix a bug that the updating of lr is before the first time of updating of optimizer TODO this will still cause warnings when GA is above 1 3. Remove the judgement of mode in text-classification's preprocessor to fit the base trainer(Bug) Update some regression bins to fit the preprocessor 4. Update the regression tool to let outer code modify atol and rtol 5. Add the default metric for text-classification task 6. Remove the useless ckpt conversion method in bert to avoid the requirement of tf when loading modeling_bert Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10430764master
@@ -0,0 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | |||||
oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 | |||||
size 151572 |
@@ -1,3 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 | |||||
size 62231 | |||||
oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b | |||||
size 61741 |
@@ -1,3 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a | |||||
size 62235 | |||||
oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 | |||||
size 61745 |
@@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||||
def generate_dummy_inputs(self, | def generate_dummy_inputs(self, | ||||
shape: Tuple = None, | shape: Tuple = None, | ||||
pair: bool = False, | |||||
**kwargs) -> Dict[str, Any]: | **kwargs) -> Dict[str, Any]: | ||||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing. | """Generate dummy inputs for model exportation to onnx or other formats by tracing. | ||||
@param shape: A tuple of input shape which should have at most two dimensions. | @param shape: A tuple of input shape which should have at most two dimensions. | ||||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | ||||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | ||||
@param pair: Generate sentence pairs or single sentences for dummy inputs. | |||||
@return: Dummy inputs. | @return: Dummy inputs. | ||||
""" | """ | ||||
@@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||||
**sequence_length | **sequence_length | ||||
}) | }) | ||||
preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | ||||
if preprocessor.pair: | |||||
if pair: | |||||
first_sequence = preprocessor.tokenizer.unk_token | first_sequence = preprocessor.tokenizer.unk_token | ||||
second_sequence = preprocessor.tokenizer.unk_token | second_sequence = preprocessor.tokenizer.unk_token | ||||
else: | else: | ||||
@@ -32,6 +32,7 @@ task_default_metrics = { | |||||
Tasks.sentiment_classification: [Metrics.seq_cls_metric], | Tasks.sentiment_classification: [Metrics.seq_cls_metric], | ||||
Tasks.token_classification: [Metrics.token_cls_metric], | Tasks.token_classification: [Metrics.token_cls_metric], | ||||
Tasks.text_generation: [Metrics.text_gen_metric], | Tasks.text_generation: [Metrics.text_gen_metric], | ||||
Tasks.text_classification: [Metrics.seq_cls_metric], | |||||
Tasks.image_denoising: [Metrics.image_denoise_metric], | Tasks.image_denoising: [Metrics.image_denoise_metric], | ||||
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | ||||
Tasks.image_portrait_enhancement: | Tasks.image_portrait_enhancement: | ||||
@@ -15,7 +15,6 @@ | |||||
"""PyTorch BERT model. """ | """PyTorch BERT model. """ | ||||
import math | import math | ||||
import os | |||||
import warnings | import warnings | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Optional, Tuple | from typing import Optional, Tuple | ||||
@@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel, | |||||
find_pruneable_heads_and_indices, | find_pruneable_heads_and_indices, | ||||
prune_linear_layer) | prune_linear_layer) | ||||
from modelscope.models.base import TorchModel | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from .configuration_bert import BertConfig | from .configuration_bert import BertConfig | ||||
@@ -50,81 +48,6 @@ logger = get_logger(__name__) | |||||
_CONFIG_FOR_DOC = 'BertConfig' | _CONFIG_FOR_DOC = 'BertConfig' | ||||
def load_tf_weights_in_bert(model, config, tf_checkpoint_path): | |||||
"""Load tf checkpoints in a pytorch model.""" | |||||
try: | |||||
import re | |||||
import numpy as np | |||||
import tensorflow as tf | |||||
except ImportError: | |||||
logger.error( | |||||
'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' | |||||
'https://www.tensorflow.org/install/ for installation instructions.' | |||||
) | |||||
raise | |||||
tf_path = os.path.abspath(tf_checkpoint_path) | |||||
logger.info(f'Converting TensorFlow checkpoint from {tf_path}') | |||||
# Load weights from TF model | |||||
init_vars = tf.train.list_variables(tf_path) | |||||
names = [] | |||||
arrays = [] | |||||
for name, shape in init_vars: | |||||
logger.info(f'Loading TF weight {name} with shape {shape}') | |||||
array = tf.train.load_variable(tf_path, name) | |||||
names.append(name) | |||||
arrays.append(array) | |||||
for name, array in zip(names, arrays): | |||||
name = name.split('/') | |||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v | |||||
# which are not required for using pretrained model | |||||
if any(n in [ | |||||
'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', | |||||
'AdamWeightDecayOptimizer_1', 'global_step' | |||||
] for n in name): | |||||
logger.info(f"Skipping {'/'.join(name)}") | |||||
continue | |||||
pointer = model | |||||
for m_name in name: | |||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name): | |||||
scope_names = re.split(r'_(\d+)', m_name) | |||||
else: | |||||
scope_names = [m_name] | |||||
if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': | |||||
pointer = getattr(pointer, 'weight') | |||||
elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': | |||||
pointer = getattr(pointer, 'bias') | |||||
elif scope_names[0] == 'output_weights': | |||||
pointer = getattr(pointer, 'weight') | |||||
elif scope_names[0] == 'squad': | |||||
pointer = getattr(pointer, 'classifier') | |||||
else: | |||||
try: | |||||
pointer = getattr(pointer, scope_names[0]) | |||||
except AttributeError: | |||||
logger.info(f"Skipping {'/'.join(name)}") | |||||
continue | |||||
if len(scope_names) >= 2: | |||||
num = int(scope_names[1]) | |||||
pointer = pointer[num] | |||||
if m_name[-11:] == '_embeddings': | |||||
pointer = getattr(pointer, 'weight') | |||||
elif m_name == 'kernel': | |||||
array = np.transpose(array) | |||||
try: | |||||
if pointer.shape != array.shape: | |||||
raise ValueError( | |||||
f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' | |||||
) | |||||
except AssertionError as e: | |||||
e.args += (pointer.shape, array.shape) | |||||
raise | |||||
logger.info(f'Initialize PyTorch weight {name}') | |||||
pointer.data = torch.from_numpy(array) | |||||
return model | |||||
class BertEmbeddings(nn.Module): | class BertEmbeddings(nn.Module): | ||||
"""Construct the embeddings from word, position and token_type embeddings.""" | """Construct the embeddings from word, position and token_type embeddings.""" | ||||
@@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel): | |||||
""" | """ | ||||
config_class = BertConfig | config_class = BertConfig | ||||
load_tf_weights = load_tf_weights_in_bert | |||||
base_model_prefix = 'bert' | base_model_prefix = 'bert' | ||||
supports_gradient_checkpointing = True | supports_gradient_checkpointing = True | ||||
_keys_to_ignore_on_load_missing = [r'position_ids'] | _keys_to_ignore_on_load_missing = [r'position_ids'] | ||||
@@ -2,7 +2,7 @@ | |||||
import os.path as osp | import os.path as osp | ||||
import re | import re | ||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |||||
import numpy as np | import numpy as np | ||||
import sentencepiece as spm | import sentencepiece as spm | ||||
@@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||||
return isinstance(label, str) or isinstance(label, int) | return isinstance(label, str) or isinstance(label, int) | ||||
if labels is not None: | if labels is not None: | ||||
if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ | |||||
if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ | |||||
and self.label2id is not None: | and self.label2id is not None: | ||||
output[OutputKeys.LABELS] = [ | output[OutputKeys.LABELS] = [ | ||||
self.label2id[str(label)] for label in labels | self.label2id[str(label)] for label in labels | ||||
@@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||||
def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | ||||
kwargs['truncation'] = kwargs.get('truncation', True) | kwargs['truncation'] = kwargs.get('truncation', True) | ||||
kwargs['padding'] = kwargs.get( | |||||
'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||||
kwargs['padding'] = kwargs.get('padding', 'max_length') | |||||
kwargs['max_length'] = kwargs.pop('sequence_length', 128) | kwargs['max_length'] = kwargs.pop('sequence_length', 128) | ||||
super().__init__(model_dir, mode=mode, **kwargs) | super().__init__(model_dir, mode=mode, **kwargs) | ||||
@@ -47,7 +47,7 @@ class LrSchedulerHook(Hook): | |||||
return lr | return lr | ||||
def before_train_iter(self, trainer): | def before_train_iter(self, trainer): | ||||
if not self.by_epoch: | |||||
if not self.by_epoch and trainer.iter > 0: | |||||
if self.warmup_lr_scheduler is not None: | if self.warmup_lr_scheduler is not None: | ||||
self.warmup_lr_scheduler.step() | self.warmup_lr_scheduler.step() | ||||
else: | else: | ||||
@@ -651,7 +651,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
# TODO: support MsDataset load for cv | # TODO: support MsDataset load for cv | ||||
if hasattr(data_cfg, 'name'): | if hasattr(data_cfg, 'name'): | ||||
dataset = MsDataset.load( | dataset = MsDataset.load( | ||||
dataset_name=data_cfg.name, | |||||
dataset_name=data_cfg.pop('name'), | |||||
**data_cfg, | **data_cfg, | ||||
) | ) | ||||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | ||||
@@ -65,7 +65,8 @@ class RegressTool: | |||||
def monitor_module_single_forward(self, | def monitor_module_single_forward(self, | ||||
module: nn.Module, | module: nn.Module, | ||||
file_name: str, | file_name: str, | ||||
compare_fn=None): | |||||
compare_fn=None, | |||||
**kwargs): | |||||
"""Monitor a pytorch module in a single forward. | """Monitor a pytorch module in a single forward. | ||||
@param module: A torch module | @param module: A torch module | ||||
@@ -107,7 +108,7 @@ class RegressTool: | |||||
baseline = os.path.join(tempfile.gettempdir(), name) | baseline = os.path.join(tempfile.gettempdir(), name) | ||||
self.load(baseline, name) | self.load(baseline, name) | ||||
with open(baseline, 'rb') as f: | with open(baseline, 'rb') as f: | ||||
baseline_json = pickle.load(f) | |||||
base = pickle.load(f) | |||||
class NumpyEncoder(json.JSONEncoder): | class NumpyEncoder(json.JSONEncoder): | ||||
"""Special json encoder for numpy types | """Special json encoder for numpy types | ||||
@@ -122,9 +123,9 @@ class RegressTool: | |||||
return obj.tolist() | return obj.tolist() | ||||
return json.JSONEncoder.default(self, obj) | return json.JSONEncoder.default(self, obj) | ||||
print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') | |||||
print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') | |||||
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') | print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') | ||||
if not compare_io_and_print(baseline_json, io_json, compare_fn): | |||||
if not compare_io_and_print(base, io_json, compare_fn, **kwargs): | |||||
raise ValueError('Result not match!') | raise ValueError('Result not match!') | ||||
@contextlib.contextmanager | @contextlib.contextmanager | ||||
@@ -136,7 +137,8 @@ class RegressTool: | |||||
ignore_keys=None, | ignore_keys=None, | ||||
compare_random=True, | compare_random=True, | ||||
reset_dropout=True, | reset_dropout=True, | ||||
lazy_stop_callback=None): | |||||
lazy_stop_callback=None, | |||||
**kwargs): | |||||
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer. | """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. | ||||
This is usually useful when you try to change some dangerous code | This is usually useful when you try to change some dangerous code | ||||
@@ -265,14 +267,15 @@ class RegressTool: | |||||
baseline_json = pickle.load(f) | baseline_json = pickle.load(f) | ||||
if level == 'strict' and not compare_io_and_print( | if level == 'strict' and not compare_io_and_print( | ||||
baseline_json['forward'], io_json, compare_fn): | |||||
baseline_json['forward'], io_json, compare_fn, **kwargs): | |||||
raise RuntimeError('Forward not match!') | raise RuntimeError('Forward not match!') | ||||
if not compare_backward_and_print( | if not compare_backward_and_print( | ||||
baseline_json['backward'], | baseline_json['backward'], | ||||
bw_json, | bw_json, | ||||
compare_fn=compare_fn, | compare_fn=compare_fn, | ||||
ignore_keys=ignore_keys, | ignore_keys=ignore_keys, | ||||
level=level): | |||||
level=level, | |||||
**kwargs): | |||||
raise RuntimeError('Backward not match!') | raise RuntimeError('Backward not match!') | ||||
cfg_opt1 = { | cfg_opt1 = { | ||||
'optimizer': baseline_json['optimizer'], | 'optimizer': baseline_json['optimizer'], | ||||
@@ -286,7 +289,8 @@ class RegressTool: | |||||
'cfg': summary['cfg'], | 'cfg': summary['cfg'], | ||||
'state': None if not compare_random else summary['state'] | 'state': None if not compare_random else summary['state'] | ||||
} | } | ||||
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): | |||||
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, | |||||
**kwargs): | |||||
raise RuntimeError('Cfg or optimizers not match!') | raise RuntimeError('Cfg or optimizers not match!') | ||||
@@ -303,7 +307,8 @@ class MsRegressTool(RegressTool): | |||||
compare_fn=None, | compare_fn=None, | ||||
ignore_keys=None, | ignore_keys=None, | ||||
compare_random=True, | compare_random=True, | ||||
lazy_stop_callback=None): | |||||
lazy_stop_callback=None, | |||||
**kwargs): | |||||
if lazy_stop_callback is None: | if lazy_stop_callback is None: | ||||
@@ -319,7 +324,7 @@ class MsRegressTool(RegressTool): | |||||
trainer.register_hook(EarlyStopHook()) | trainer.register_hook(EarlyStopHook()) | ||||
def _train_loop(trainer, *args, **kwargs): | |||||
def _train_loop(trainer, *args_train, **kwargs_train): | |||||
with self.monitor_module_train( | with self.monitor_module_train( | ||||
trainer, | trainer, | ||||
file_name, | file_name, | ||||
@@ -327,9 +332,11 @@ class MsRegressTool(RegressTool): | |||||
compare_fn=compare_fn, | compare_fn=compare_fn, | ||||
ignore_keys=ignore_keys, | ignore_keys=ignore_keys, | ||||
compare_random=compare_random, | compare_random=compare_random, | ||||
lazy_stop_callback=lazy_stop_callback): | |||||
lazy_stop_callback=lazy_stop_callback, | |||||
**kwargs): | |||||
try: | try: | ||||
return trainer.train_loop_origin(*args, **kwargs) | |||||
return trainer.train_loop_origin(*args_train, | |||||
**kwargs_train) | |||||
except MsRegressTool.EarlyStopError: | except MsRegressTool.EarlyStopError: | ||||
pass | pass | ||||
@@ -530,7 +537,8 @@ def compare_arguments_nested(print_content, | |||||
) | ) | ||||
return False | return False | ||||
if not all([ | if not all([ | ||||
compare_arguments_nested(None, sub_arg1, sub_arg2) | |||||
compare_arguments_nested( | |||||
None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) | |||||
for sub_arg1, sub_arg2 in zip(arg1, arg2) | for sub_arg1, sub_arg2 in zip(arg1, arg2) | ||||
]): | ]): | ||||
if print_content is not None: | if print_content is not None: | ||||
@@ -551,7 +559,8 @@ def compare_arguments_nested(print_content, | |||||
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') | print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') | ||||
return False | return False | ||||
if not all([ | if not all([ | ||||
compare_arguments_nested(None, arg1[key], arg2[key]) | |||||
compare_arguments_nested( | |||||
None, arg1[key], arg2[key], rtol=rtol, atol=atol) | |||||
for key in keys1 | for key in keys1 | ||||
]): | ]): | ||||
if print_content is not None: | if print_content is not None: | ||||
@@ -574,7 +583,7 @@ def compare_arguments_nested(print_content, | |||||
raise ValueError(f'type not supported: {type1}') | raise ValueError(f'type not supported: {type1}') | ||||
def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||||
def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): | |||||
if compare_fn is None: | if compare_fn is None: | ||||
def compare_fn(*args, **kwargs): | def compare_fn(*args, **kwargs): | ||||
@@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||||
else: | else: | ||||
match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
f'unmatched module {key} input args', v1input['args'], | f'unmatched module {key} input args', v1input['args'], | ||||
v2input['args']) and match | |||||
v2input['args'], **kwargs) and match | |||||
match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
f'unmatched module {key} input kwargs', v1input['kwargs'], | f'unmatched module {key} input kwargs', v1input['kwargs'], | ||||
v2input['kwargs']) and match | |||||
v2input['kwargs'], **kwargs) and match | |||||
v1output = numpify_tensor_nested(v1['output']) | v1output = numpify_tensor_nested(v1['output']) | ||||
v2output = numpify_tensor_nested(v2['output']) | v2output = numpify_tensor_nested(v2['output']) | ||||
res = compare_fn(v1output, v2output, key, 'output') | res = compare_fn(v1output, v2output, key, 'output') | ||||
@@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): | |||||
) | ) | ||||
match = match and res | match = match and res | ||||
else: | else: | ||||
match = compare_arguments_nested(f'unmatched module {key} outputs', | |||||
v1output, v2output) and match | |||||
match = compare_arguments_nested( | |||||
f'unmatched module {key} outputs', | |||||
arg1=v1output, | |||||
arg2=v2output, | |||||
**kwargs) and match | |||||
return match | return match | ||||
@@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json, | |||||
bw_json, | bw_json, | ||||
level, | level, | ||||
ignore_keys=None, | ignore_keys=None, | ||||
compare_fn=None): | |||||
compare_fn=None, | |||||
**kwargs): | |||||
if compare_fn is None: | if compare_fn is None: | ||||
def compare_fn(*args, **kwargs): | def compare_fn(*args, **kwargs): | ||||
@@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json, | |||||
data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ | data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ | ||||
'grad'], bw_json[key]['data_after'] | 'grad'], bw_json[key]['data_after'] | ||||
match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
f'unmatched module {key} tensor data', data1, data2) and match | |||||
f'unmatched module {key} tensor data', | |||||
arg1=data1, | |||||
arg2=data2, | |||||
**kwargs) and match | |||||
if level == 'strict': | if level == 'strict': | ||||
match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
f'unmatched module {key} grad data', grad1, | |||||
grad2) and match | |||||
f'unmatched module {key} grad data', | |||||
arg1=grad1, | |||||
arg2=grad2, | |||||
**kwargs) and match | |||||
match = compare_arguments_nested( | match = compare_arguments_nested( | ||||
f'unmatched module {key} data after step', data_after1, | f'unmatched module {key} data after step', data_after1, | ||||
data_after2) and match | |||||
data_after2, **kwargs) and match | |||||
return match | return match | ||||
def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
def compare_cfg_and_optimizers(baseline_json, | |||||
cfg_json, | |||||
compare_fn=None, | |||||
**kwargs): | |||||
if compare_fn is None: | if compare_fn is None: | ||||
def compare_fn(*args, **kwargs): | def compare_fn(*args, **kwargs): | ||||
@@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
print( | print( | ||||
f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" | f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" | ||||
) | ) | ||||
match = compare_arguments_nested('unmatched optimizer defaults', | |||||
optimizer1['defaults'], | |||||
optimizer2['defaults']) and match | |||||
match = compare_arguments_nested('unmatched optimizer state_dict', | |||||
optimizer1['state_dict'], | |||||
optimizer2['state_dict']) and match | |||||
match = compare_arguments_nested( | |||||
'unmatched optimizer defaults', optimizer1['defaults'], | |||||
optimizer2['defaults'], **kwargs) and match | |||||
match = compare_arguments_nested( | |||||
'unmatched optimizer state_dict', optimizer1['state_dict'], | |||||
optimizer2['state_dict'], **kwargs) and match | |||||
res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') | res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') | ||||
if res is not None: | if res is not None: | ||||
@@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
print( | print( | ||||
f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" | f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" | ||||
) | ) | ||||
match = compare_arguments_nested('unmatched lr_scheduler state_dict', | |||||
lr_scheduler1['state_dict'], | |||||
lr_scheduler2['state_dict']) and match | |||||
match = compare_arguments_nested( | |||||
'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], | |||||
lr_scheduler2['state_dict'], **kwargs) and match | |||||
res = compare_fn(cfg1, cfg2, None, 'cfg') | res = compare_fn(cfg1, cfg2, None, 'cfg') | ||||
if res is not None: | if res is not None: | ||||
print(f'cfg compared with user compare_fn with result:{res}\n') | print(f'cfg compared with user compare_fn with result:{res}\n') | ||||
match = match and res | match = match and res | ||||
else: | else: | ||||
match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match | |||||
match = compare_arguments_nested( | |||||
'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match | |||||
res = compare_fn(state1, state2, None, 'state') | res = compare_fn(state1, state2, None, 'state') | ||||
if res is not None: | if res is not None: | ||||
@@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): | |||||
match = match and res | match = match and res | ||||
else: | else: | ||||
match = compare_arguments_nested('unmatched random state', state1, | match = compare_arguments_nested('unmatched random state', state1, | ||||
state2) and match | |||||
state2, **kwargs) and match | |||||
return match | return match |
@@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase): | |||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
def test_ms_csv_basic(self): | def test_ms_csv_basic(self): | ||||
ms_ds_train = MsDataset.load( | ms_ds_train = MsDataset.load( | ||||
'afqmc_small', namespace='userxiaoming', split='train') | |||||
'clue', subset_name='afqmc', | |||||
split='train').to_hf_dataset().select(range(5)) | |||||
print(next(iter(ms_ds_train))) | print(next(iter(ms_ds_train))) | ||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
@@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ | |||||
calculate_fisher | calculate_fisher | ||||
from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
from modelscope.utils.data_utils import to_device | from modelscope.utils.data_utils import to_device | ||||
from modelscope.utils.regress_test_utils import MsRegressTool | |||||
from modelscope.utils.regress_test_utils import (MsRegressTool, | |||||
compare_arguments_nested) | |||||
from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
@@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
def test_trainer_repeatable(self): | def test_trainer_repeatable(self): | ||||
import torch # noqa | import torch # noqa | ||||
def compare_fn(value1, value2, key, type): | |||||
# Ignore the differences between optimizers of two torch versions | |||||
if type != 'optimizer': | |||||
return None | |||||
match = (value1['type'] == value2['type']) | |||||
shared_defaults = set(value1['defaults'].keys()).intersection( | |||||
set(value2['defaults'].keys())) | |||||
match = all([ | |||||
compare_arguments_nested(f'Optimizer defaults {key} not match', | |||||
value1['defaults'][key], | |||||
value2['defaults'][key]) | |||||
for key in shared_defaults | |||||
]) and match | |||||
match = (len(value1['state_dict']['param_groups']) == len( | |||||
value2['state_dict']['param_groups'])) and match | |||||
for group1, group2 in zip(value1['state_dict']['param_groups'], | |||||
value2['state_dict']['param_groups']): | |||||
shared_keys = set(group1.keys()).intersection( | |||||
set(group2.keys())) | |||||
match = all([ | |||||
compare_arguments_nested( | |||||
f'Optimizer param_groups {key} not match', group1[key], | |||||
group2[key]) for key in shared_keys | |||||
]) and match | |||||
return match | |||||
def cfg_modify_fn(cfg): | def cfg_modify_fn(cfg): | ||||
cfg.task = 'nli' | cfg.task = 'nli' | ||||
cfg['preprocessor'] = {'type': 'nli-tokenizer'} | cfg['preprocessor'] = {'type': 'nli-tokenizer'} | ||||
@@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
name=Trainers.nlp_base_trainer, default_args=kwargs) | name=Trainers.nlp_base_trainer, default_args=kwargs) | ||||
with self.regress_tool.monitor_ms_train( | with self.regress_tool.monitor_ms_train( | ||||
trainer, 'sbert-base-tnews', level='strict'): | |||||
trainer, 'sbert-base-tnews', level='strict', | |||||
compare_fn=compare_fn): | |||||
trainer.train() | trainer.train() | ||||
def finetune(self, | def finetune(self, | ||||
@@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
os.makedirs(self.tmp_dir) | os.makedirs(self.tmp_dir) | ||||
self.dataset = MsDataset.load( | self.dataset = MsDataset.load( | ||||
'afqmc_small', namespace='userxiaoming', split='train') | |||||
'clue', subset_name='afqmc', | |||||
split='train').to_hf_dataset().select(range(2)) | |||||
def tearDown(self): | def tearDown(self): | ||||
shutil.rmtree(self.tmp_dir) | shutil.rmtree(self.tmp_dir) | ||||
@@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | ||||
pipeline_sentence_similarity(output_dir) | pipeline_sentence_similarity(output_dir) | ||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level') | |||||
def test_trainer_with_backbone_head(self): | def test_trainer_with_backbone_head(self): | ||||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | ||||
kwargs = dict( | kwargs = dict( | ||||
@@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | ||||
cfg = read_config(model_id, revision='beta') | cfg = read_config(model_id, revision='beta') | ||||
cfg.train.max_epochs = 20 | cfg.train.max_epochs = 20 | ||||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||||
cfg.train.work_dir = self.tmp_dir | cfg.train.work_dir = self.tmp_dir | ||||
cfg_file = os.path.join(self.tmp_dir, 'config.json') | cfg_file = os.path.join(self.tmp_dir, 'config.json') | ||||
cfg.dump(cfg_file) | cfg.dump(cfg_file) | ||||
@@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | ||||
self.assertTrue(Metrics.accuracy in eval_results) | self.assertTrue(Metrics.accuracy in eval_results) | ||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
def test_trainer_with_configured_datasets(self): | def test_trainer_with_configured_datasets(self): | ||||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | ||||
cfg: Config = read_config(model_id) | cfg: Config = read_config(model_id) | ||||
cfg.train.max_epochs = 20 | cfg.train.max_epochs = 20 | ||||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||||
cfg.train.work_dir = self.tmp_dir | cfg.train.work_dir = self.tmp_dir | ||||
cfg.dataset = { | cfg.dataset = { | ||||
'train': { | 'train': { | ||||
'name': 'afqmc_small', | |||||
'name': 'clue', | |||||
'subset_name': 'afqmc', | |||||
'split': 'train', | 'split': 'train', | ||||
'namespace': 'userxiaoming' | |||||
}, | }, | ||||
'val': { | 'val': { | ||||
'name': 'afqmc_small', | |||||
'name': 'clue', | |||||
'subset_name': 'afqmc', | |||||
'split': 'train', | 'split': 'train', | ||||
'namespace': 'userxiaoming' | |||||
}, | }, | ||||
} | } | ||||
cfg_file = os.path.join(self.tmp_dir, 'config.json') | cfg_file = os.path.join(self.tmp_dir, 'config.json') | ||||
@@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | ||||
cfg: Config = read_config(model_id) | cfg: Config = read_config(model_id) | ||||
cfg.train.max_epochs = 3 | cfg.train.max_epochs = 3 | ||||
cfg.preprocessor.first_sequence = 'sentence1' | |||||
cfg.preprocessor.second_sequence = 'sentence2' | |||||
cfg.preprocessor.label = 'label' | |||||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||||
cfg.train.work_dir = self.tmp_dir | cfg.train.work_dir = self.tmp_dir | ||||
cfg_file = os.path.join(self.tmp_dir, 'config.json') | cfg_file = os.path.join(self.tmp_dir, 'config.json') | ||||
cfg.dump(cfg_file) | cfg.dump(cfg_file) | ||||