From fbde374659b31466f48124c79cc26c852553ca9f Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Tue, 30 Aug 2022 23:17:07 +0800 Subject: [PATCH] [to #42322933] add regress tests Add regression test for some unit tests. Firstly, Run a baseline test to create a pickle file which contains the inputs and outputs of modules, then changes can be observed between the latest version and the baseline file. Some baseline files are submitted in the data/test/regression folder Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9814693 --- .gitattributes | 2 + data/test/regression/fill_mask_bert_zh.bin | 3 + data/test/regression/fill_mask_sbert_en.bin | 3 + data/test/regression/fill_mask_sbert_zh.bin | 3 + data/test/regression/fill_mask_veco_en.bin | 3 + data/test/regression/fill_mask_veco_zh.bin | 3 + data/test/regression/sbert_nli.bin | 3 + data/test/regression/sbert_sen_sim.bin | 3 + data/test/regression/sbert_ws_en.bin | 3 + data/test/regression/sbert_ws_zh.bin | 3 + data/test/regression/sbert_zero_shot.bin | 3 + modelscope/utils/regress_test_utils.py | 703 ++++++++++++++++++ tests/pipelines/test_fill_mask.py | 23 +- tests/pipelines/test_nli.py | 7 +- tests/pipelines/test_sentence_similarity.py | 6 +- .../test_sentiment_classification.py | 1 - tests/pipelines/test_word_segmentation.py | 11 +- .../test_zero_shot_classification.py | 9 +- tests/run.py | 1 + 19 files changed, 777 insertions(+), 16 deletions(-) create mode 100644 data/test/regression/fill_mask_bert_zh.bin create mode 100644 data/test/regression/fill_mask_sbert_en.bin create mode 100644 data/test/regression/fill_mask_sbert_zh.bin create mode 100644 data/test/regression/fill_mask_veco_en.bin create mode 100644 data/test/regression/fill_mask_veco_zh.bin create mode 100644 data/test/regression/sbert_nli.bin create mode 100644 data/test/regression/sbert_sen_sim.bin create mode 100644 data/test/regression/sbert_ws_en.bin create mode 100644 data/test/regression/sbert_ws_zh.bin create mode 100644 data/test/regression/sbert_zero_shot.bin create mode 100644 modelscope/utils/regress_test_utils.py diff --git a/.gitattributes b/.gitattributes index 60ff0dd2..1a3015ec 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,4 +4,6 @@ *.wav filter=lfs diff=lfs merge=lfs -text *.JPEG filter=lfs diff=lfs merge=lfs -text *.jpeg filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text *.avi filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text diff --git a/data/test/regression/fill_mask_bert_zh.bin b/data/test/regression/fill_mask_bert_zh.bin new file mode 100644 index 00000000..17c28b81 --- /dev/null +++ b/data/test/regression/fill_mask_bert_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:541183383bb06aa3ca2c44a68cd51c1be5e3e984a1dee2c58092b9552660f3ce +size 61883 diff --git a/data/test/regression/fill_mask_sbert_en.bin b/data/test/regression/fill_mask_sbert_en.bin new file mode 100644 index 00000000..09aaf300 --- /dev/null +++ b/data/test/regression/fill_mask_sbert_en.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f0afcd9d2aa5ac9569114203bd9db4f1a520c903a88fd4854370cdde0e7eab7 +size 119940 diff --git a/data/test/regression/fill_mask_sbert_zh.bin b/data/test/regression/fill_mask_sbert_zh.bin new file mode 100644 index 00000000..812f7ba2 --- /dev/null +++ b/data/test/regression/fill_mask_sbert_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280 +size 119940 diff --git a/data/test/regression/fill_mask_veco_en.bin b/data/test/regression/fill_mask_veco_en.bin new file mode 100644 index 00000000..be3fddc8 --- /dev/null +++ b/data/test/regression/fill_mask_veco_en.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705 +size 119619 diff --git a/data/test/regression/fill_mask_veco_zh.bin b/data/test/regression/fill_mask_veco_zh.bin new file mode 100644 index 00000000..c0d27e20 --- /dev/null +++ b/data/test/regression/fill_mask_veco_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a +size 119619 diff --git a/data/test/regression/sbert_nli.bin b/data/test/regression/sbert_nli.bin new file mode 100644 index 00000000..a5f680bb --- /dev/null +++ b/data/test/regression/sbert_nli.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 +size 62231 diff --git a/data/test/regression/sbert_sen_sim.bin b/data/test/regression/sbert_sen_sim.bin new file mode 100644 index 00000000..a59cbe0b --- /dev/null +++ b/data/test/regression/sbert_sen_sim.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a +size 62235 diff --git a/data/test/regression/sbert_ws_en.bin b/data/test/regression/sbert_ws_en.bin new file mode 100644 index 00000000..4eb562d6 --- /dev/null +++ b/data/test/regression/sbert_ws_en.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572 +size 60801 diff --git a/data/test/regression/sbert_ws_zh.bin b/data/test/regression/sbert_ws_zh.bin new file mode 100644 index 00000000..555f640d --- /dev/null +++ b/data/test/regression/sbert_ws_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c +size 60801 diff --git a/data/test/regression/sbert_zero_shot.bin b/data/test/regression/sbert_zero_shot.bin new file mode 100644 index 00000000..23d40946 --- /dev/null +++ b/data/test/regression/sbert_zero_shot.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85 +size 61589 diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py new file mode 100644 index 00000000..ca50d579 --- /dev/null +++ b/modelscope/utils/regress_test_utils.py @@ -0,0 +1,703 @@ +import contextlib +import hashlib +import os +import pickle +import random +import shutil +import tempfile +from collections.abc import Mapping +from pathlib import Path +from types import FunctionType +from typing import Any, Dict, Union + +import json +import numpy as np +import torch.optim +from torch import nn + + +class RegressTool: + """This class is used to stop inference/training results from changing by some unaware affections by unittests. + + Firstly, run a baseline test to create a result file, then changes can be observed between + the latest version and the baseline file. + """ + + def __init__(self, + baseline: bool = None, + store_func: FunctionType = None, + load_func: FunctionType = None): + """A func to store the baseline file and a func to load the baseline file. + """ + self.baseline = baseline + self.store_func = store_func + self.load_func = load_func + print(f'Current working dir is: {Path.cwd()}') + + def store(self, local, remote): + if self.store_func is not None: + self.store_func(local, remote) + else: + path = os.path.abspath( + os.path.join(Path.cwd(), 'data', 'test', 'regression')) + os.makedirs(path, exist_ok=True) + shutil.copy(local, os.path.join(path, remote)) + + def load(self, local, remote): + if self.load_func is not None: + self.load_func(local, remote) + else: + path = os.path.abspath( + os.path.join(Path.cwd(), 'data', 'test', 'regression')) + baseline = os.path.join(path, remote) + if not os.path.exists(baseline): + raise ValueError(f'base line file {baseline} not exist') + print( + f'local file found:{baseline}, md5:{hashlib.md5(open(baseline,"rb").read()).hexdigest()}' + ) + if os.path.exists(local): + os.remove(local) + os.symlink(baseline, local, target_is_directory=False) + + @contextlib.contextmanager + def monitor_module_single_forward(self, + module: nn.Module, + file_name: str, + compare_fn=None): + """Monitor a pytorch module in a single forward. + + @param module: A torch module + @param file_name: The file_name to store or load file + @param compare_fn: A custom fn used to compare the results manually. + + >>> def compare_fn(v1, v2, key, type): + >>> return None + + v1 is the baseline value + v2 is the value of current version + key is the key of submodules + type is in one of 'input', 'output' + """ + baseline = os.getenv('REGRESSION_BASELINE') + if baseline is None or self.baseline is None: + yield + return + + baseline = self.baseline + io_json = {} + absolute_path = f'./{file_name}.bin' + if not isinstance(module, nn.Module): + assert hasattr(module, 'model') + module = module.model + + hack_forward(module, file_name, io_json) + intercept_module(module, io_json) + yield + hack_forward(module, None, None, restore=True) + intercept_module(module, None, restore=True) + if baseline: + with open(absolute_path, 'wb') as f: + pickle.dump(io_json, f) + self.store(absolute_path, f'{file_name}.bin') + os.remove(absolute_path) + else: + name = os.path.basename(absolute_path) + baseline = os.path.join(tempfile.gettempdir(), name) + self.load(baseline, name) + with open(baseline, 'rb') as f: + baseline_json = pickle.load(f) + + class NumpyEncoder(json.JSONEncoder): + """Special json encoder for numpy types + """ + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') + print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') + if not compare_io_and_print(baseline_json, io_json, compare_fn): + raise ValueError('Result not match!') + + @contextlib.contextmanager + def monitor_module_train(self, + trainer: Union[Dict, Any], + file_name, + level='config', + compare_fn=None, + ignore_keys=None, + compare_random=True, + lazy_stop_callback=None): + """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. + + This is usually useful when you try to change some dangerous code + which has the risk of affecting the training loop. + + @param trainer: A dict or an object contains the model/optimizer/lr_scheduler + @param file_name: The file_name to store or load file + @param level: The regression level. + 'strict' for matching every single tensor. + Please make sure the parameters of head are fixed + and the drop-out rate is zero. + 'config' for matching the initial config, like cfg file, optimizer param_groups, + lr_scheduler params and the random seed. + 'metric' for compare the best metrics in the evaluation loop. + @param compare_fn: A custom fn used to compare the results manually. + @param ignore_keys: The keys to ignore of the named_parameters. + @param compare_random: If to compare random setttings, default True. + @param lazy_stop_callback: A callback passed in, when the moniting is over, this callback will be called. + + >>> def compare_fn(v1, v2, key, type): + >>> return None + + v1 is the baseline value + v2 is the value of current version + key is the key of modules/parameters + type is in one of 'input', 'output', 'backward', 'optimizer', 'lr_scheduler', 'cfg', 'state' + """ + baseline = os.getenv('REGRESSION_BASELINE') + if baseline is None or self.baseline is None: + yield + return + + baseline = self.baseline + + io_json = {} + bw_json = {} + absolute_path = f'./{file_name}.bin' + + if level == 'strict': + print( + "[Important] The level of regression is 'strict', please make sure your model's parameters are " + 'fixed and all drop-out rates have been set to zero.') + + assert hasattr( + trainer, 'model') or 'model' in trainer, 'model must be in trainer' + module = trainer['model'] if isinstance(trainer, + dict) else trainer.model + if not isinstance(module, nn.Module): + assert hasattr(module, 'model') + module = module.model + + assert hasattr( + trainer, 'optimizer' + ) or 'optimizer' in trainer, 'optimizer must be in trainer' + assert hasattr( + trainer, 'lr_scheduler' + ) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer' + optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance( + trainer, dict) else trainer.optimizer + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \ + else trainer.lr_scheduler + torch_state = numpify_tensor_nested(torch.get_rng_state()) + np_state = np.random.get_state() + random_seed = random.getstate() + seed = trainer._seed if hasattr( + trainer, + '_seed') else trainer.seed if hasattr(trainer, 'seed') else None + + if level == 'strict': + hack_forward(module, file_name, io_json) + intercept_module(module, io_json) + hack_backward( + module, optimizer, bw_json, lazy_stop_callback=lazy_stop_callback) + yield + hack_backward(module, optimizer, None, restore=True) + if level == 'strict': + hack_forward(module, None, None, restore=True) + intercept_module(module, None, restore=True) + + optimizer_dict = optimizer.state_dict() + optimizer_dict.pop('state', None) + summary = { + 'forward': io_json, + 'backward': bw_json, + 'optimizer': { + 'type': optimizer.__class__.__name__, + 'defaults': optimizer.defaults, + 'state_dict': optimizer_dict + }, + 'lr_scheduler': { + 'type': lr_scheduler.__class__.__name__, + 'state_dict': lr_scheduler.state_dict() + }, + 'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None, + 'state': { + 'torch_state': torch_state, + 'np_state': np_state, + 'random_seed': random_seed, + 'seed': seed, + } + } + + if baseline: + with open(absolute_path, 'wb') as f: + pickle.dump(summary, f) + self.store(absolute_path, f'{file_name}.bin') + os.remove(absolute_path) + else: + name = os.path.basename(absolute_path) + baseline = os.path.join(tempfile.gettempdir(), name) + self.load(baseline, name) + with open(baseline, 'rb') as f: + baseline_json = pickle.load(f) + + if level == 'strict' and not compare_io_and_print( + baseline_json['forward'], io_json, compare_fn): + raise RuntimeError('Forward not match!') + if not compare_backward_and_print( + baseline_json['backward'], + bw_json, + compare_fn=compare_fn, + ignore_keys=ignore_keys, + level=level): + raise RuntimeError('Backward not match!') + cfg_opt1 = { + 'optimizer': baseline_json['optimizer'], + 'lr_scheduler': baseline_json['lr_scheduler'], + 'cfg': baseline_json['cfg'], + 'state': None if not compare_random else baseline_json['state'] + } + cfg_opt2 = { + 'optimizer': summary['optimizer'], + 'lr_scheduler': summary['lr_scheduler'], + 'cfg': summary['cfg'], + 'state': None if not compare_random else summary['state'] + } + if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): + raise RuntimeError('Cfg or optimizers not match!') + + +class MsRegressTool(RegressTool): + + class EarlyStopError(Exception): + pass + + @contextlib.contextmanager + def monitor_ms_train(self, + trainer, + file_name, + level='config', + compare_fn=None, + ignore_keys=None): + + def lazy_stop_callback(): + + from modelscope.trainers.hooks.hook import Hook, Priority + + class EarlyStopHook(Hook): + PRIORITY = Priority.VERY_LOW + + def after_iter(self, trainer): + raise MsRegressTool.EarlyStopError('Test finished.') + + trainer.register_hook(EarlyStopHook()) + + def _train_loop(trainer, *args, **kwargs): + with self.monitor_module_train( + trainer, + file_name, + level, + compare_fn=compare_fn, + ignore_keys=ignore_keys, + lazy_stop_callback=lazy_stop_callback): + try: + return trainer.train_loop_origin(*args, **kwargs) + except MsRegressTool.EarlyStopError: + pass + + trainer.train_loop_origin, trainer.train_loop = \ + trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer) + yield + + +def compare_module(module1: nn.Module, module2: nn.Module): + for p1, p2 in zip(module1.parameters(), module2.parameters()): + if p1.data.ne(p2.data).sum() > 0: + return False + return True + + +def numpify_tensor_nested(tensors, reduction=None, clip_value=10000): + import torch + "Numpify `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)( + numpify_tensor_nested(t, reduction, clip_value) for t in tensors) + if isinstance(tensors, Mapping): + return type(tensors)({ + k: numpify_tensor_nested(t, reduction, clip_value) + for k, t in tensors.items() + }) + if isinstance(tensors, torch.Tensor): + t: np.ndarray = tensors.cpu().numpy() + if clip_value is not None: + t = np.where(t > clip_value, clip_value, t) + t = np.where(t < -clip_value, -clip_value, t) + if reduction == 'sum': + return t.sum(dtype=np.float) + elif reduction == 'mean': + return t.mean(dtype=np.float) + return t + return tensors + + +def detach_tensor_nested(tensors): + import torch + "Detach `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(detach_tensor_nested(t) for t in tensors) + if isinstance(tensors, Mapping): + return type(tensors)( + {k: detach_tensor_nested(t) + for k, t in tensors.items()}) + if isinstance(tensors, torch.Tensor): + return tensors.detach() + return tensors + + +def hack_forward(module: nn.Module, + name, + io_json, + restore=False, + keep_tensors=False): + + def _forward(self, *args, **kwargs): + ret = self.forward_origin(*args, **kwargs) + if keep_tensors: + args = numpify_tensor_nested(detach_tensor_nested(args)) + kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs)) + output = numpify_tensor_nested(detach_tensor_nested(ret)) + else: + args = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(args), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(args), reduction='mean'), + } + kwargs = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(kwargs), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(kwargs), reduction='mean'), + } + output = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(ret), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(ret), reduction='mean'), + } + + io_json[name] = { + 'input': { + 'args': args, + 'kwargs': kwargs, + }, + 'output': output, + } + return ret + + if not restore and not hasattr(module, 'forward_origin'): + module.forward_origin, module.forward = module.forward, type( + module.forward)(_forward, module) + if restore and hasattr(module, 'forward_origin'): + module.forward = module.forward_origin + del module.forward_origin + + +def hack_backward(module: nn.Module, + optimizer, + io_json, + restore=False, + lazy_stop_callback=None): + + def _step(self, *args, **kwargs): + for name, param in module.named_parameters(): + io_json[name] = { + 'data': { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='mean'), + }, + 'grad': { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(param.grad), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(param.grad), reduction='mean'), + } + } + ret = self.step_origin(*args, **kwargs) + for name, param in module.named_parameters(): + io_json[name]['data_after'] = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='mean'), + } + if lazy_stop_callback is not None: + lazy_stop_callback() + return ret + + if not restore and not hasattr(optimizer, 'step_origin'): + optimizer.step_origin, optimizer.step = optimizer.step, type( + optimizer.state_dict)(_step, optimizer) + if restore and hasattr(optimizer, 'step_origin'): + optimizer.step = optimizer.step_origin + del optimizer.step_origin + + +def intercept_module(module: nn.Module, + io_json, + parent_name=None, + restore=False): + for name, module in module.named_children(): + full_name = parent_name + '.' + name if parent_name is not None else name + hack_forward(module, full_name, io_json, restore) + intercept_module(module, io_json, full_name, restore) + + +def compare_arguments_nested(print_content, arg1, arg2): + type1 = type(arg1) + type2 = type(arg2) + if type1.__name__ != type2.__name__: + if print_content is not None: + print( + f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}' + ) + return False + + if arg1 is None: + return True + elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)): + if arg1 != arg2: + if print_content is not None: + print(f'{print_content}, arg1:{arg1}, arg2:{arg2}') + return False + return True + elif isinstance(arg1, (float, np.floating)): + if not np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8, equal_nan=True): + if print_content is not None: + print(f'{print_content}, arg1:{arg1}, arg2:{arg2}') + return False + return True + elif isinstance(arg1, (tuple, list)): + if len(arg1) != len(arg2): + if print_content is not None: + print( + f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}' + ) + return False + if not all([ + compare_arguments_nested(None, sub_arg1, sub_arg2) + for sub_arg1, sub_arg2 in zip(arg1, arg2) + ]): + if print_content is not None: + print(f'{print_content}') + return False + return True + elif isinstance(arg1, Mapping): + keys1 = arg1.keys() + keys2 = arg2.keys() + if len(keys1) != len(keys2): + if print_content is not None: + print( + f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}' + ) + return False + if len(set(keys1) - set(keys2)) > 0: + if print_content is not None: + print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') + return False + if not all([ + compare_arguments_nested(None, arg1[key], arg2[key]) + for key in keys1 + ]): + if print_content is not None: + print(f'{print_content}') + return False + return True + elif isinstance(arg1, np.ndarray): + arg1 = np.where(np.equal(arg1, None), np.NaN, + arg1).astype(dtype=np.float) + arg2 = np.where(np.equal(arg2, None), np.NaN, + arg2).astype(dtype=np.float) + if not all( + np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8, + equal_nan=True).flatten()): + if print_content is not None: + print(f'{print_content}') + return False + return True + else: + raise ValueError(f'type not supported: {type1}') + + +def compare_io_and_print(baseline_json, io_json, compare_fn=None): + if compare_fn is None: + + def compare_fn(*args, **kwargs): + return None + + keys1 = set(baseline_json.keys()) + keys2 = set(io_json.keys()) + added = keys1 - keys2 + removed = keys2 - keys1 + print(f'unmatched keys: {added}, {removed}') + shared_keys = keys1.intersection(keys2) + match = True + for key in shared_keys: + v1 = baseline_json[key] + v2 = io_json[key] + + v1input = numpify_tensor_nested(v1['input']) + v2input = numpify_tensor_nested(v2['input']) + res = compare_fn(v1input, v2input, key, 'input') + if res is not None: + print( + f'input of {key} compared with user compare_fn with result:{res}\n' + ) + match = match and res + else: + match = compare_arguments_nested( + f'unmatched module {key} input args', v1input['args'], + v2input['args']) and match + match = compare_arguments_nested( + f'unmatched module {key} input kwargs', v1input['kwargs'], + v2input['kwargs']) and match + v1output = numpify_tensor_nested(v1['output']) + v2output = numpify_tensor_nested(v2['output']) + res = compare_fn(v1output, v2output, key, 'output') + if res is not None: + print( + f'output of {key} compared with user compare_fn with result:{res}\n' + ) + match = match and res + else: + match = compare_arguments_nested(f'unmatched module {key} outputs', + v1output, v2output) and match + return match + + +def compare_backward_and_print(baseline_json, + bw_json, + level, + ignore_keys=None, + compare_fn=None): + if compare_fn is None: + + def compare_fn(*args, **kwargs): + return None + + keys1 = set(baseline_json.keys()) + keys2 = set(bw_json.keys()) + added = keys1 - keys2 + removed = keys2 - keys1 + print(f'unmatched backward keys: {added}, {removed}') + shared_keys = keys1.intersection(keys2) + match = True + for key in shared_keys: + if ignore_keys is not None and key in ignore_keys: + continue + + res = compare_fn(baseline_json[key], bw_json[key], key, 'backward') + if res is not None: + print(f'backward data of {key} compared with ' + f'user compare_fn with result:{res}\n') + match = match and res + else: + data1, grad1, data_after1 = baseline_json[key][ + 'data'], baseline_json[key]['grad'], baseline_json[key][ + 'data_after'] + data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ + 'grad'], bw_json[key]['data_after'] + match = compare_arguments_nested( + f'unmatched module {key} tensor data', data1, data2) and match + if level == 'strict': + match = compare_arguments_nested( + f'unmatched module {key} grad data', grad1, + grad2) and match + match = compare_arguments_nested( + f'unmatched module {key} data after step', data_after1, + data_after2) and match + return match + + +def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): + if compare_fn is None: + + def compare_fn(*args, **kwargs): + return None + + optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[ + 'optimizer'], baseline_json['lr_scheduler'], baseline_json[ + 'cfg'], baseline_json['state'] + optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[ + 'lr_scheduler'], cfg_json['cfg'], baseline_json['state'] + + match = True + res = compare_fn(optimizer1, optimizer2, None, 'optimizer') + if res is not None: + print(f'optimizer compared with user compare_fn with result:{res}\n') + match = match and res + else: + if optimizer1['type'] != optimizer2['type']: + print( + f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" + ) + match = compare_arguments_nested('unmatched optimizer defaults', + optimizer1['defaults'], + optimizer2['defaults']) and match + match = compare_arguments_nested('unmatched optimizer state_dict', + optimizer1['state_dict'], + optimizer2['state_dict']) and match + + res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') + if res is not None: + print( + f'lr_scheduler compared with user compare_fn with result:{res}\n') + match = match and res + else: + if lr_scheduler1['type'] != lr_scheduler2['type']: + print( + f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" + ) + match = compare_arguments_nested('unmatched lr_scheduler state_dict', + lr_scheduler1['state_dict'], + lr_scheduler2['state_dict']) and match + + res = compare_fn(cfg1, cfg2, None, 'cfg') + if res is not None: + print(f'cfg compared with user compare_fn with result:{res}\n') + match = match and res + else: + match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match + + res = compare_fn(state1, state2, None, 'state') + if res is not None: + print( + f'random state compared with user compare_fn with result:{res}\n') + match = match and res + else: + match = compare_arguments_nested('unmatched random state', state1, + state2) and match + + return match diff --git a/tests/pipelines/test_fill_mask.py b/tests/pipelines/test_fill_mask.py index 2f57b2d8..1b709e27 100644 --- a/tests/pipelines/test_fill_mask.py +++ b/tests/pipelines/test_fill_mask.py @@ -9,6 +9,7 @@ from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import FillMaskPipeline from modelscope.preprocessors import FillMaskPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.regress_test_utils import MsRegressTool from modelscope.utils.test_utils import test_level @@ -37,6 +38,7 @@ class FillMaskTest(unittest.TestCase): 'Everything in [MASK] you call reality is really [MASK] a reflection of your ' '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.' } + regress_tool = MsRegressTool(baseline=False) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): @@ -98,9 +100,11 @@ class FillMaskTest(unittest.TestCase): second_sequence=None) pipeline_ins = pipeline( task=Tasks.fill_mask, model=model, preprocessor=preprocessor) - print( - f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' - f'{pipeline_ins(self.test_inputs[language])}\n') + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, f'fill_mask_sbert_{language}'): + print( + f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' + f'{pipeline_ins(self.test_inputs[language])}\n') # veco model = Model.from_pretrained(self.model_id_veco) @@ -111,8 +115,11 @@ class FillMaskTest(unittest.TestCase): for language in ['zh', 'en']: ori_text = self.ori_texts[language] test_input = self.test_inputs[language].replace('[MASK]', '') - print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' - f'{pipeline_ins(test_input)}\n') + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, f'fill_mask_veco_{language}'): + print( + f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') # zh bert model = Model.from_pretrained(self.model_id_bert) @@ -123,8 +130,10 @@ class FillMaskTest(unittest.TestCase): language = 'zh' ori_text = self.ori_texts[language] test_input = self.test_inputs[language] - print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' - f'{pipeline_ins(test_input)}\n') + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, 'fill_mask_bert_zh'): + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py index 1e259a2e..1d3fba12 100644 --- a/tests/pipelines/test_nli.py +++ b/tests/pipelines/test_nli.py @@ -8,6 +8,7 @@ from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import PairSentenceClassificationPipeline from modelscope.preprocessors import PairSentenceClassificationPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.regress_test_utils import MsRegressTool from modelscope.utils.test_utils import test_level @@ -15,6 +16,7 @@ class NLITest(unittest.TestCase): model_id = 'damo/nlp_structbert_nli_chinese-base' sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' sentence2 = '四川商务职业学院商务管理在哪个校区?' + regress_tool = MsRegressTool(baseline=False) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_direct_file_download(self): @@ -26,7 +28,6 @@ class NLITest(unittest.TestCase): pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer) print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}') - print() print( f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') @@ -42,7 +43,9 @@ class NLITest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id) - print(pipeline_ins(input=(self.sentence1, self.sentence2))) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, 'sbert_nli'): + print(pipeline_ins(input=(self.sentence1, self.sentence2))) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py index d39f6783..6990bf75 100644 --- a/tests/pipelines/test_sentence_similarity.py +++ b/tests/pipelines/test_sentence_similarity.py @@ -8,6 +8,7 @@ from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import PairSentenceClassificationPipeline from modelscope.preprocessors import PairSentenceClassificationPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.regress_test_utils import MsRegressTool from modelscope.utils.test_utils import test_level @@ -15,6 +16,7 @@ class SentenceSimilarityTest(unittest.TestCase): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' sentence1 = '今天气温比昨天高么?' sentence2 = '今天湿度比昨天高么?' + regress_tool = MsRegressTool(baseline=False) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run(self): @@ -47,7 +49,9 @@ class SentenceSimilarityTest(unittest.TestCase): def test_run_with_model_name(self): pipeline_ins = pipeline( task=Tasks.sentence_similarity, model=self.model_id) - print(pipeline_ins(input=(self.sentence1, self.sentence2))) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, 'sbert_sen_sim'): + print(pipeline_ins(input=(self.sentence1, self.sentence2))) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py index f3bc6981..35c96282 100644 --- a/tests/pipelines/test_sentiment_classification.py +++ b/tests/pipelines/test_sentiment_classification.py @@ -30,7 +30,6 @@ class SentimentClassificationTaskModelTest(unittest.TestCase): preprocessor=tokenizer) print(f'sentence1: {self.sentence1}\n' f'pipeline1:{pipeline1(input=self.sentence1)}') - print() print(f'sentence1: {self.sentence1}\n' f'pipeline1: {pipeline2(input=self.sentence1)}') diff --git a/tests/pipelines/test_word_segmentation.py b/tests/pipelines/test_word_segmentation.py index c332d987..87006f96 100644 --- a/tests/pipelines/test_word_segmentation.py +++ b/tests/pipelines/test_word_segmentation.py @@ -9,6 +9,7 @@ from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import WordSegmentationPipeline from modelscope.preprocessors import TokenClassificationPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.regress_test_utils import MsRegressTool from modelscope.utils.test_utils import test_level @@ -16,6 +17,7 @@ class WordSegmentationTest(unittest.TestCase): model_id = 'damo/nlp_structbert_word-segmentation_chinese-base' sentence = '今天天气不错,适合出去游玩' sentence_eng = 'I am a program.' + regress_tool = MsRegressTool(baseline=False) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): @@ -27,7 +29,6 @@ class WordSegmentationTest(unittest.TestCase): Tasks.word_segmentation, model=model, preprocessor=tokenizer) print(f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence)}') - print() print(f'pipeline2: {pipeline2(input=self.sentence)}') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -42,8 +43,12 @@ class WordSegmentationTest(unittest.TestCase): def test_run_with_model_name(self): pipeline_ins = pipeline( task=Tasks.word_segmentation, model=self.model_id) - print(pipeline_ins(input=self.sentence)) - print(pipeline_ins(input=self.sentence_eng)) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, 'sbert_ws_zh'): + print(pipeline_ins(input=self.sentence)) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, 'sbert_ws_en'): + print(pipeline_ins(input=self.sentence_eng)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py index 7620a0ed..f0f2a481 100644 --- a/tests/pipelines/test_zero_shot_classification.py +++ b/tests/pipelines/test_zero_shot_classification.py @@ -8,6 +8,7 @@ from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import ZeroShotClassificationPipeline from modelscope.preprocessors import ZeroShotClassificationPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.regress_test_utils import MsRegressTool from modelscope.utils.test_utils import test_level @@ -16,6 +17,7 @@ class ZeroShotClassificationTest(unittest.TestCase): sentence = '全新突破 解放军运20版空中加油机曝光' labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] template = '这篇文章的标题是{}' + regress_tool = MsRegressTool(baseline=False) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_direct_file_download(self): @@ -33,7 +35,6 @@ class ZeroShotClassificationTest(unittest.TestCase): f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' ) - print() print( f'sentence: {self.sentence}\n' f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' @@ -53,7 +54,11 @@ class ZeroShotClassificationTest(unittest.TestCase): def test_run_with_model_name(self): pipeline_ins = pipeline( task=Tasks.zero_shot_classification, model=self.model_id) - print(pipeline_ins(input=self.sentence, candidate_labels=self.labels)) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, 'sbert_zero_shot'): + print( + pipeline_ins( + input=self.sentence, candidate_labels=self.labels)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): diff --git a/tests/run.py b/tests/run.py index 1a601eda..79509745 100644 --- a/tests/run.py +++ b/tests/run.py @@ -334,6 +334,7 @@ if __name__ == '__main__': help='Save result to directory, internal use only') args = parser.parse_args() set_test_level(args.level) + os.environ['REGRESSION_BASELINE'] = '1' logger.info(f'TEST LEVEL: {test_level()}') if not args.disable_profile: from utils import profiler