
[to #42322933] add regression tests

Add regression tests to some of the unit tests.
First, run a baseline test to create a pickle file containing the inputs and outputs of the modules; changes
between the latest version and the baseline file can then be detected.
Some baseline files are committed in the data/test/regression folder.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9814693
master
yuze.zyz 3 years ago
parent
commit
fbde374659
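
A minimal usage sketch, adapted from the test_nli.py change below (not itself part of this commit); MsRegressTool and monitor_module_single_forward are the helpers added in modelscope/utils/regress_test_utils.py:

    import unittest
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.utils.regress_test_utils import MsRegressTool

    class NLITest(unittest.TestCase):
        # baseline=False: load data/test/regression/sbert_nli.bin and compare against it.
        # baseline=True records a new baseline instead. Monitoring only runs when the
        # REGRESSION_BASELINE environment variable is set (tests/run.py sets it).
        regress_tool = MsRegressTool(baseline=False)

        def test_run_with_model_name(self):
            pipeline_ins = pipeline(
                task=Tasks.nli, model='damo/nlp_structbert_nli_chinese-base')
            with self.regress_tool.monitor_module_single_forward(
                    pipeline_ins.model, 'sbert_nli'):
                print(pipeline_ins(input=('四川商务职业学院和四川财经职业学院哪个好?',
                                          '四川商务职业学院商务管理在哪个校区?')))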
19 changed files with 777 additions and 16 deletions
  1. .gitattributes  +2 -0
  2. data/test/regression/fill_mask_bert_zh.bin  +3 -0
  3. data/test/regression/fill_mask_sbert_en.bin  +3 -0
  4. data/test/regression/fill_mask_sbert_zh.bin  +3 -0
  5. data/test/regression/fill_mask_veco_en.bin  +3 -0
  6. data/test/regression/fill_mask_veco_zh.bin  +3 -0
  7. data/test/regression/sbert_nli.bin  +3 -0
  8. data/test/regression/sbert_sen_sim.bin  +3 -0
  9. data/test/regression/sbert_ws_en.bin  +3 -0
  10. data/test/regression/sbert_ws_zh.bin  +3 -0
  11. data/test/regression/sbert_zero_shot.bin  +3 -0
  12. modelscope/utils/regress_test_utils.py  +703 -0
  13. tests/pipelines/test_fill_mask.py  +16 -7
  14. tests/pipelines/test_nli.py  +5 -2
  15. tests/pipelines/test_sentence_similarity.py  +5 -1
  16. tests/pipelines/test_sentiment_classification.py  +0 -1
  17. tests/pipelines/test_word_segmentation.py  +8 -3
  18. tests/pipelines/test_zero_shot_classification.py  +7 -2
  19. tests/run.py  +1 -0

+ 2
- 0
.gitattributes

@@ -4,4 +4,6 @@
 *.wav filter=lfs diff=lfs merge=lfs -text
 *.JPEG filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
 *.avi filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text

+ 3
- 0
data/test/regression/fill_mask_bert_zh.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:541183383bb06aa3ca2c44a68cd51c1be5e3e984a1dee2c58092b9552660f3ce
size 61883

+ 3
- 0
data/test/regression/fill_mask_sbert_en.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f0afcd9d2aa5ac9569114203bd9db4f1a520c903a88fd4854370cdde0e7eab7
size 119940

+ 3
- 0
data/test/regression/fill_mask_sbert_zh.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280
size 119940

+ 3
- 0
data/test/regression/fill_mask_veco_en.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705
size 119619

+ 3
- 0
data/test/regression/fill_mask_veco_zh.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a
size 119619

+ 3
- 0
data/test/regression/sbert_nli.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62
size 62231

+ 3
- 0
data/test/regression/sbert_sen_sim.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a
size 62235

+ 3
- 0
data/test/regression/sbert_ws_en.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572
size 60801

+ 3
- 0
data/test/regression/sbert_ws_zh.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c
size 60801

+ 3
- 0
data/test/regression/sbert_zero_shot.bin

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85
size 61589

+ 703
- 0
modelscope/utils/regress_test_utils.py

@@ -0,0 +1,703 @@
import contextlib
import hashlib
import os
import pickle
import random
import shutil
import tempfile
from collections.abc import Mapping
from pathlib import Path
from types import FunctionType
from typing import Any, Dict, Union

import json
import numpy as np
import torch.optim
from torch import nn


class RegressTool:
"""This class is used to stop inference/training results from changing by some unaware affections by unittests.

Firstly, run a baseline test to create a result file, then changes can be observed between
the latest version and the baseline file.
"""

def __init__(self,
baseline: bool = None,
store_func: FunctionType = None,
load_func: FunctionType = None):
"""A func to store the baseline file and a func to load the baseline file.
"""
self.baseline = baseline
self.store_func = store_func
self.load_func = load_func
print(f'Current working dir is: {Path.cwd()}')

def store(self, local, remote):
if self.store_func is not None:
self.store_func(local, remote)
else:
path = os.path.abspath(
os.path.join(Path.cwd(), 'data', 'test', 'regression'))
os.makedirs(path, exist_ok=True)
shutil.copy(local, os.path.join(path, remote))

def load(self, local, remote):
if self.load_func is not None:
self.load_func(local, remote)
else:
path = os.path.abspath(
os.path.join(Path.cwd(), 'data', 'test', 'regression'))
baseline = os.path.join(path, remote)
if not os.path.exists(baseline):
raise ValueError(f'baseline file {baseline} does not exist')
print(
f'local file found:{baseline}, md5:{hashlib.md5(open(baseline,"rb").read()).hexdigest()}'
)
if os.path.exists(local):
os.remove(local)
os.symlink(baseline, local, target_is_directory=False)

@contextlib.contextmanager
def monitor_module_single_forward(self,
module: nn.Module,
file_name: str,
compare_fn=None):
"""Monitor a pytorch module in a single forward.

@param module: A torch module
@param file_name: The file_name to store or load file
@param compare_fn: A custom fn used to compare the results manually.

>>> def compare_fn(v1, v2, key, type):
>>> return None

v1 is the baseline value
v2 is the value of the current version
key is the name of the submodule
type is one of 'input' or 'output'
"""
baseline = os.getenv('REGRESSION_BASELINE')
if baseline is None or self.baseline is None:
yield
return

baseline = self.baseline
io_json = {}
absolute_path = f'./{file_name}.bin'
if not isinstance(module, nn.Module):
assert hasattr(module, 'model')
module = module.model

hack_forward(module, file_name, io_json)
intercept_module(module, io_json)
yield
hack_forward(module, None, None, restore=True)
intercept_module(module, None, restore=True)
if baseline:
with open(absolute_path, 'wb') as f:
pickle.dump(io_json, f)
self.store(absolute_path, f'{file_name}.bin')
os.remove(absolute_path)
else:
name = os.path.basename(absolute_path)
baseline = os.path.join(tempfile.gettempdir(), name)
self.load(baseline, name)
with open(baseline, 'rb') as f:
baseline_json = pickle.load(f)

class NumpyEncoder(json.JSONEncoder):
"""Special json encoder for numpy types
"""

def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}')
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}')
if not compare_io_and_print(baseline_json, io_json, compare_fn):
raise ValueError('Result not match!')

@contextlib.contextmanager
def monitor_module_train(self,
trainer: Union[Dict, Any],
file_name,
level='config',
compare_fn=None,
ignore_keys=None,
compare_random=True,
lazy_stop_callback=None):
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer.

This is usually useful when you try to change some dangerous code
which has the risk of affecting the training loop.

@param trainer: A dict or an object contains the model/optimizer/lr_scheduler
@param file_name: The file_name to store or load file
@param level: The regression level.
'strict' for matching every single tensor.
Please make sure the parameters of head are fixed
and the drop-out rate is zero.
'config' for matching the initial config, like cfg file, optimizer param_groups,
lr_scheduler params and the random seed.
'metric' for comparing the best metrics in the evaluation loop.
@param compare_fn: A custom fn used to compare the results manually.
@param ignore_keys: The keys to ignore of the named_parameters.
@param compare_random: Whether to compare random settings, default True.
@param lazy_stop_callback: A callback that will be called when the monitoring is over.

>>> def compare_fn(v1, v2, key, type):
>>> return None

v1 is the baseline value
v2 is the value of the current version
key is the name of the module/parameter
type is one of 'input', 'output', 'backward', 'optimizer', 'lr_scheduler', 'cfg', 'state'
"""
baseline = os.getenv('REGRESSION_BASELINE')
if baseline is None or self.baseline is None:
yield
return

baseline = self.baseline

io_json = {}
bw_json = {}
absolute_path = f'./{file_name}.bin'

if level == 'strict':
print(
"[Important] The level of regression is 'strict', please make sure your model's parameters are "
'fixed and all drop-out rates have been set to zero.')

assert hasattr(
trainer, 'model') or 'model' in trainer, 'model must be in trainer'
module = trainer['model'] if isinstance(trainer,
dict) else trainer.model
if not isinstance(module, nn.Module):
assert hasattr(module, 'model')
module = module.model

assert hasattr(
trainer, 'optimizer'
) or 'optimizer' in trainer, 'optimizer must be in trainer'
assert hasattr(
trainer, 'lr_scheduler'
) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer'
optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance(
trainer, dict) else trainer.optimizer
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \
else trainer.lr_scheduler
torch_state = numpify_tensor_nested(torch.get_rng_state())
np_state = np.random.get_state()
random_seed = random.getstate()
seed = trainer._seed if hasattr(
trainer,
'_seed') else trainer.seed if hasattr(trainer, 'seed') else None

if level == 'strict':
hack_forward(module, file_name, io_json)
intercept_module(module, io_json)
hack_backward(
module, optimizer, bw_json, lazy_stop_callback=lazy_stop_callback)
yield
hack_backward(module, optimizer, None, restore=True)
if level == 'strict':
hack_forward(module, None, None, restore=True)
intercept_module(module, None, restore=True)

optimizer_dict = optimizer.state_dict()
optimizer_dict.pop('state', None)
summary = {
'forward': io_json,
'backward': bw_json,
'optimizer': {
'type': optimizer.__class__.__name__,
'defaults': optimizer.defaults,
'state_dict': optimizer_dict
},
'lr_scheduler': {
'type': lr_scheduler.__class__.__name__,
'state_dict': lr_scheduler.state_dict()
},
'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None,
'state': {
'torch_state': torch_state,
'np_state': np_state,
'random_seed': random_seed,
'seed': seed,
}
}

if baseline:
with open(absolute_path, 'wb') as f:
pickle.dump(summary, f)
self.store(absolute_path, f'{file_name}.bin')
os.remove(absolute_path)
else:
name = os.path.basename(absolute_path)
baseline = os.path.join(tempfile.gettempdir(), name)
self.load(baseline, name)
with open(baseline, 'rb') as f:
baseline_json = pickle.load(f)

if level == 'strict' and not compare_io_and_print(
baseline_json['forward'], io_json, compare_fn):
raise RuntimeError('Forward not match!')
if not compare_backward_and_print(
baseline_json['backward'],
bw_json,
compare_fn=compare_fn,
ignore_keys=ignore_keys,
level=level):
raise RuntimeError('Backward not match!')
cfg_opt1 = {
'optimizer': baseline_json['optimizer'],
'lr_scheduler': baseline_json['lr_scheduler'],
'cfg': baseline_json['cfg'],
'state': None if not compare_random else baseline_json['state']
}
cfg_opt2 = {
'optimizer': summary['optimizer'],
'lr_scheduler': summary['lr_scheduler'],
'cfg': summary['cfg'],
'state': None if not compare_random else summary['state']
}
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn):
raise RuntimeError('Cfg or optimizers not match!')


class MsRegressTool(RegressTool):

class EarlyStopError(Exception):
pass

@contextlib.contextmanager
def monitor_ms_train(self,
trainer,
file_name,
level='config',
compare_fn=None,
ignore_keys=None):

def lazy_stop_callback():

from modelscope.trainers.hooks.hook import Hook, Priority

class EarlyStopHook(Hook):
PRIORITY = Priority.VERY_LOW

def after_iter(self, trainer):
raise MsRegressTool.EarlyStopError('Test finished.')

trainer.register_hook(EarlyStopHook())

def _train_loop(trainer, *args, **kwargs):
with self.monitor_module_train(
trainer,
file_name,
level,
compare_fn=compare_fn,
ignore_keys=ignore_keys,
lazy_stop_callback=lazy_stop_callback):
try:
return trainer.train_loop_origin(*args, **kwargs)
except MsRegressTool.EarlyStopError:
pass

trainer.train_loop_origin, trainer.train_loop = \
trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer)
yield


def compare_module(module1: nn.Module, module2: nn.Module):
for p1, p2 in zip(module1.parameters(), module2.parameters()):
if p1.data.ne(p2.data).sum() > 0:
return False
return True


def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
"""Numpify `tensors` (even if it's a nested list/tuple of tensors)."""
import torch
if isinstance(tensors, (list, tuple)):
return type(tensors)(
numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
if isinstance(tensors, Mapping):
return type(tensors)({
k: numpify_tensor_nested(t, reduction, clip_value)
for k, t in tensors.items()
})
if isinstance(tensors, torch.Tensor):
t: np.ndarray = tensors.cpu().numpy()
if clip_value is not None:
t = np.where(t > clip_value, clip_value, t)
t = np.where(t < -clip_value, -clip_value, t)
if reduction == 'sum':
return t.sum(dtype=np.float)
elif reduction == 'mean':
return t.mean(dtype=np.float)
return t
return tensors


def detach_tensor_nested(tensors):
"""Detach `tensors` (even if it's a nested list/tuple of tensors)."""
import torch
if isinstance(tensors, (list, tuple)):
return type(tensors)(detach_tensor_nested(t) for t in tensors)
if isinstance(tensors, Mapping):
return type(tensors)(
{k: detach_tensor_nested(t)
for k, t in tensors.items()})
if isinstance(tensors, torch.Tensor):
return tensors.detach()
return tensors


def hack_forward(module: nn.Module,
name,
io_json,
restore=False,
keep_tensors=False):

def _forward(self, *args, **kwargs):
ret = self.forward_origin(*args, **kwargs)
if keep_tensors:
args = numpify_tensor_nested(detach_tensor_nested(args))
kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs))
output = numpify_tensor_nested(detach_tensor_nested(ret))
else:
args = {
'sum':
numpify_tensor_nested(
detach_tensor_nested(args), reduction='sum'),
'mean':
numpify_tensor_nested(
detach_tensor_nested(args), reduction='mean'),
}
kwargs = {
'sum':
numpify_tensor_nested(
detach_tensor_nested(kwargs), reduction='sum'),
'mean':
numpify_tensor_nested(
detach_tensor_nested(kwargs), reduction='mean'),
}
output = {
'sum':
numpify_tensor_nested(
detach_tensor_nested(ret), reduction='sum'),
'mean':
numpify_tensor_nested(
detach_tensor_nested(ret), reduction='mean'),
}

io_json[name] = {
'input': {
'args': args,
'kwargs': kwargs,
},
'output': output,
}
return ret

if not restore and not hasattr(module, 'forward_origin'):
module.forward_origin, module.forward = module.forward, type(
module.forward)(_forward, module)
if restore and hasattr(module, 'forward_origin'):
module.forward = module.forward_origin
del module.forward_origin


def hack_backward(module: nn.Module,
optimizer,
io_json,
restore=False,
lazy_stop_callback=None):

def _step(self, *args, **kwargs):
for name, param in module.named_parameters():
io_json[name] = {
'data': {
'sum':
numpify_tensor_nested(
detach_tensor_nested(param.data), reduction='sum'),
'mean':
numpify_tensor_nested(
detach_tensor_nested(param.data), reduction='mean'),
},
'grad': {
'sum':
numpify_tensor_nested(
detach_tensor_nested(param.grad), reduction='sum'),
'mean':
numpify_tensor_nested(
detach_tensor_nested(param.grad), reduction='mean'),
}
}
ret = self.step_origin(*args, **kwargs)
for name, param in module.named_parameters():
io_json[name]['data_after'] = {
'sum':
numpify_tensor_nested(
detach_tensor_nested(param.data), reduction='sum'),
'mean':
numpify_tensor_nested(
detach_tensor_nested(param.data), reduction='mean'),
}
if lazy_stop_callback is not None:
lazy_stop_callback()
return ret

if not restore and not hasattr(optimizer, 'step_origin'):
optimizer.step_origin, optimizer.step = optimizer.step, type(
optimizer.state_dict)(_step, optimizer)
if restore and hasattr(optimizer, 'step_origin'):
optimizer.step = optimizer.step_origin
del optimizer.step_origin


def intercept_module(module: nn.Module,
io_json,
parent_name=None,
restore=False):
for name, module in module.named_children():
full_name = parent_name + '.' + name if parent_name is not None else name
hack_forward(module, full_name, io_json, restore)
intercept_module(module, io_json, full_name, restore)


def compare_arguments_nested(print_content, arg1, arg2):
type1 = type(arg1)
type2 = type(arg2)
if type1.__name__ != type2.__name__:
if print_content is not None:
print(
f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}'
)
return False

if arg1 is None:
return True
elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)):
if arg1 != arg2:
if print_content is not None:
print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
return False
return True
elif isinstance(arg1, (float, np.floating)):
if not np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8, equal_nan=True):
if print_content is not None:
print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
return False
return True
elif isinstance(arg1, (tuple, list)):
if len(arg1) != len(arg2):
if print_content is not None:
print(
f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}'
)
return False
if not all([
compare_arguments_nested(None, sub_arg1, sub_arg2)
for sub_arg1, sub_arg2 in zip(arg1, arg2)
]):
if print_content is not None:
print(f'{print_content}')
return False
return True
elif isinstance(arg1, Mapping):
keys1 = arg1.keys()
keys2 = arg2.keys()
if len(keys1) != len(keys2):
if print_content is not None:
print(
f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}'
)
return False
if len(set(keys1) - set(keys2)) > 0:
if print_content is not None:
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
return False
if not all([
compare_arguments_nested(None, arg1[key], arg2[key])
for key in keys1
]):
if print_content is not None:
print(f'{print_content}')
return False
return True
elif isinstance(arg1, np.ndarray):
arg1 = np.where(np.equal(arg1, None), np.NaN,
arg1).astype(dtype=np.float)
arg2 = np.where(np.equal(arg2, None), np.NaN,
arg2).astype(dtype=np.float)
if not all(
np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8,
equal_nan=True).flatten()):
if print_content is not None:
print(f'{print_content}')
return False
return True
else:
raise ValueError(f'type not supported: {type1}')


def compare_io_and_print(baseline_json, io_json, compare_fn=None):
if compare_fn is None:

def compare_fn(*args, **kwargs):
return None

keys1 = set(baseline_json.keys())
keys2 = set(io_json.keys())
added = keys1 - keys2
removed = keys2 - keys1
print(f'unmatched keys: {added}, {removed}')
shared_keys = keys1.intersection(keys2)
match = True
for key in shared_keys:
v1 = baseline_json[key]
v2 = io_json[key]

v1input = numpify_tensor_nested(v1['input'])
v2input = numpify_tensor_nested(v2['input'])
res = compare_fn(v1input, v2input, key, 'input')
if res is not None:
print(
f'input of {key} compared with user compare_fn with result:{res}\n'
)
match = match and res
else:
match = compare_arguments_nested(
f'unmatched module {key} input args', v1input['args'],
v2input['args']) and match
match = compare_arguments_nested(
f'unmatched module {key} input kwargs', v1input['kwargs'],
v2input['kwargs']) and match
v1output = numpify_tensor_nested(v1['output'])
v2output = numpify_tensor_nested(v2['output'])
res = compare_fn(v1output, v2output, key, 'output')
if res is not None:
print(
f'output of {key} compared with user compare_fn with result:{res}\n'
)
match = match and res
else:
match = compare_arguments_nested(f'unmatched module {key} outputs',
v1output, v2output) and match
return match


def compare_backward_and_print(baseline_json,
bw_json,
level,
ignore_keys=None,
compare_fn=None):
if compare_fn is None:

def compare_fn(*args, **kwargs):
return None

keys1 = set(baseline_json.keys())
keys2 = set(bw_json.keys())
added = keys1 - keys2
removed = keys2 - keys1
print(f'unmatched backward keys: {added}, {removed}')
shared_keys = keys1.intersection(keys2)
match = True
for key in shared_keys:
if ignore_keys is not None and key in ignore_keys:
continue

res = compare_fn(baseline_json[key], bw_json[key], key, 'backward')
if res is not None:
print(f'backward data of {key} compared with '
f'user compare_fn with result:{res}\n')
match = match and res
else:
data1, grad1, data_after1 = baseline_json[key][
'data'], baseline_json[key]['grad'], baseline_json[key][
'data_after']
data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
'grad'], bw_json[key]['data_after']
match = compare_arguments_nested(
f'unmatched module {key} tensor data', data1, data2) and match
if level == 'strict':
match = compare_arguments_nested(
f'unmatched module {key} grad data', grad1,
grad2) and match
match = compare_arguments_nested(
f'unmatched module {key} data after step', data_after1,
data_after2) and match
return match


def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
if compare_fn is None:

def compare_fn(*args, **kwargs):
return None

optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[
'optimizer'], baseline_json['lr_scheduler'], baseline_json[
'cfg'], baseline_json['state']
optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[
'lr_scheduler'], cfg_json['cfg'], cfg_json['state']

match = True
res = compare_fn(optimizer1, optimizer2, None, 'optimizer')
if res is not None:
print(f'optimizer compared with user compare_fn with result:{res}\n')
match = match and res
else:
if optimizer1['type'] != optimizer2['type']:
print(
f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}"
)
match = compare_arguments_nested('unmatched optimizer defaults',
optimizer1['defaults'],
optimizer2['defaults']) and match
match = compare_arguments_nested('unmatched optimizer state_dict',
optimizer1['state_dict'],
optimizer2['state_dict']) and match

res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
if res is not None:
print(
f'lr_scheduler compared with user compare_fn with result:{res}\n')
match = match and res
else:
if lr_scheduler1['type'] != lr_scheduler2['type']:
print(
f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}"
)
match = compare_arguments_nested('unmatched lr_scheduler state_dict',
lr_scheduler1['state_dict'],
lr_scheduler2['state_dict']) and match

res = compare_fn(cfg1, cfg2, None, 'cfg')
if res is not None:
print(f'cfg compared with user compare_fn with result:{res}\n')
match = match and res
else:
match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match

res = compare_fn(state1, state2, None, 'state')
if res is not None:
print(
f'random state compared with user compare_fn with result:{res}\n')
match = match and res
else:
match = compare_arguments_nested('unmatched random state', state1,
state2) and match

return match
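
For reference, a hedged sketch of a custom compare_fn with the signature described in the docstrings above; the submodule key 'encoder.layer.11' is only an illustrative name, and returning None falls back to the default nested comparison:

    import numpy as np

    def tolerant_compare_fn(v1, v2, key, type):
        # v1: baseline value, v2: current value, key: submodule name,
        # type: 'input' or 'output' (for monitor_module_single_forward).
        if key == 'encoder.layer.11' and type == 'output':
            try:
                m1 = np.asarray(v1['mean'], dtype=float).ravel()
                m2 = np.asarray(v2['mean'], dtype=float).ravel()
            except (KeyError, TypeError, ValueError):
                return None  # structure not as expected, use the default comparison
            # loosen the tolerance for this one noisy submodule only
            return bool(np.allclose(m1, m2, rtol=1e-2, atol=1e-5, equal_nan=True))
        return None  # None means: fall back to compare_arguments_nested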

+ 16
- 7
tests/pipelines/test_fill_mask.py

@@ -9,6 +9,7 @@ from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import FillMaskPipeline
 from modelscope.preprocessors import FillMaskPreprocessor
 from modelscope.utils.constant import Tasks
+from modelscope.utils.regress_test_utils import MsRegressTool
 from modelscope.utils.test_utils import test_level




@@ -37,6 +38,7 @@ class FillMaskTest(unittest.TestCase):
 'Everything in [MASK] you call reality is really [MASK] a reflection of your '
 '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.'
 }
+regress_tool = MsRegressTool(baseline=False)

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_by_direct_model_download(self):
@@ -98,9 +100,11 @@ class FillMaskTest(unittest.TestCase):
 second_sequence=None)
 pipeline_ins = pipeline(
 task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
-print(
-f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
-f'{pipeline_ins(self.test_inputs[language])}\n')
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, f'fill_mask_sbert_{language}'):
+print(
+f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
+f'{pipeline_ins(self.test_inputs[language])}\n')

 # veco
 model = Model.from_pretrained(self.model_id_veco)
@@ -111,8 +115,11 @@ class FillMaskTest(unittest.TestCase):
 for language in ['zh', 'en']:
 ori_text = self.ori_texts[language]
 test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
-print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
-f'{pipeline_ins(test_input)}\n')
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, f'fill_mask_veco_{language}'):
+print(
+f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
+f'{pipeline_ins(test_input)}\n')

 # zh bert
 model = Model.from_pretrained(self.model_id_bert)
@@ -123,8 +130,10 @@ class FillMaskTest(unittest.TestCase):
 language = 'zh'
 ori_text = self.ori_texts[language]
 test_input = self.test_inputs[language]
-print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
-f'{pipeline_ins(test_input)}\n')
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, 'fill_mask_bert_zh'):
+print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
+f'{pipeline_ins(test_input)}\n')

 @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
 def test_run_with_model_name(self):


+ 5
- 2
tests/pipelines/test_nli.py

@@ -8,6 +8,7 @@ from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import PairSentenceClassificationPipeline
 from modelscope.preprocessors import PairSentenceClassificationPreprocessor
 from modelscope.utils.constant import Tasks
+from modelscope.utils.regress_test_utils import MsRegressTool
 from modelscope.utils.test_utils import test_level




@@ -15,6 +16,7 @@ class NLITest(unittest.TestCase):
 model_id = 'damo/nlp_structbert_nli_chinese-base'
 sentence1 = '四川商务职业学院和四川财经职业学院哪个好?'
 sentence2 = '四川商务职业学院商务管理在哪个校区?'
+regress_tool = MsRegressTool(baseline=False)

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_with_direct_file_download(self):
@@ -26,7 +28,6 @@ class NLITest(unittest.TestCase):
 pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer)
 print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
 f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}')
-print()
 print(
 f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
 f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}')
@@ -42,7 +43,9 @@ class NLITest(unittest.TestCase):
 @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
 def test_run_with_model_name(self):
 pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id)
-print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, 'sbert_nli'):
+print(pipeline_ins(input=(self.sentence1, self.sentence2)))

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_with_default_model(self):


+ 5
- 1
tests/pipelines/test_sentence_similarity.py

@@ -8,6 +8,7 @@ from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import PairSentenceClassificationPipeline
 from modelscope.preprocessors import PairSentenceClassificationPreprocessor
 from modelscope.utils.constant import Tasks
+from modelscope.utils.regress_test_utils import MsRegressTool
 from modelscope.utils.test_utils import test_level




@@ -15,6 +16,7 @@ class SentenceSimilarityTest(unittest.TestCase):
 model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
 sentence1 = '今天气温比昨天高么?'
 sentence2 = '今天湿度比昨天高么?'
+regress_tool = MsRegressTool(baseline=False)

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run(self):
@@ -47,7 +49,9 @@ class SentenceSimilarityTest(unittest.TestCase):
 def test_run_with_model_name(self):
 pipeline_ins = pipeline(
 task=Tasks.sentence_similarity, model=self.model_id)
-print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, 'sbert_sen_sim'):
+print(pipeline_ins(input=(self.sentence1, self.sentence2)))

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_with_default_model(self):


+ 0
- 1
tests/pipelines/test_sentiment_classification.py

@@ -30,7 +30,6 @@ class SentimentClassificationTaskModelTest(unittest.TestCase):
 preprocessor=tokenizer)
 print(f'sentence1: {self.sentence1}\n'
 f'pipeline1:{pipeline1(input=self.sentence1)}')
-print()
 print(f'sentence1: {self.sentence1}\n'
 f'pipeline1: {pipeline2(input=self.sentence1)}')




+ 8
- 3
tests/pipelines/test_word_segmentation.py

@@ -9,6 +9,7 @@ from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import WordSegmentationPipeline
 from modelscope.preprocessors import TokenClassificationPreprocessor
 from modelscope.utils.constant import Tasks
+from modelscope.utils.regress_test_utils import MsRegressTool
 from modelscope.utils.test_utils import test_level




@@ -16,6 +17,7 @@ class WordSegmentationTest(unittest.TestCase):
 model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
 sentence = '今天天气不错,适合出去游玩'
 sentence_eng = 'I am a program.'
+regress_tool = MsRegressTool(baseline=False)

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_by_direct_model_download(self):
@@ -27,7 +29,6 @@ class WordSegmentationTest(unittest.TestCase):
 Tasks.word_segmentation, model=model, preprocessor=tokenizer)
 print(f'sentence: {self.sentence}\n'
 f'pipeline1:{pipeline1(input=self.sentence)}')
-print()
 print(f'pipeline2: {pipeline2(input=self.sentence)}')

 @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -42,8 +43,12 @@ class WordSegmentationTest(unittest.TestCase):
 def test_run_with_model_name(self):
 pipeline_ins = pipeline(
 task=Tasks.word_segmentation, model=self.model_id)
-print(pipeline_ins(input=self.sentence))
-print(pipeline_ins(input=self.sentence_eng))
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, 'sbert_ws_zh'):
+print(pipeline_ins(input=self.sentence))
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, 'sbert_ws_en'):
+print(pipeline_ins(input=self.sentence_eng))

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_with_default_model(self):


+ 7
- 2
tests/pipelines/test_zero_shot_classification.py

@@ -8,6 +8,7 @@ from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import ZeroShotClassificationPipeline
 from modelscope.preprocessors import ZeroShotClassificationPreprocessor
 from modelscope.utils.constant import Tasks
+from modelscope.utils.regress_test_utils import MsRegressTool
 from modelscope.utils.test_utils import test_level




@@ -16,6 +17,7 @@ class ZeroShotClassificationTest(unittest.TestCase):
 sentence = '全新突破 解放军运20版空中加油机曝光'
 labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
 template = '这篇文章的标题是{}'
+regress_tool = MsRegressTool(baseline=False)

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_with_direct_file_download(self):
@@ -33,7 +35,6 @@ class ZeroShotClassificationTest(unittest.TestCase):
 f'sentence: {self.sentence}\n'
 f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
 )
-print()
 print(
 f'sentence: {self.sentence}\n'
 f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
@@ -53,7 +54,11 @@ class ZeroShotClassificationTest(unittest.TestCase):
 def test_run_with_model_name(self):
 pipeline_ins = pipeline(
 task=Tasks.zero_shot_classification, model=self.model_id)
-print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
+with self.regress_tool.monitor_module_single_forward(
+pipeline_ins.model, 'sbert_zero_shot'):
+print(
+pipeline_ins(
+input=self.sentence, candidate_labels=self.labels))

 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 def test_run_with_default_model(self):


+ 1
- 0
tests/run.py

@@ -334,6 +334,7 @@ if __name__ == '__main__':
 help='Save result to directory, internal use only')
 args = parser.parse_args()
 set_test_level(args.level)
+os.environ['REGRESSION_BASELINE'] = '1'
 logger.info(f'TEST LEVEL: {test_level()}')
 if not args.disable_profile:
 from utils import profiler


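The environment variable set in tests/run.py only arms the monitors; each monitor also checks the tool's own baseline flag, so refreshing a baseline would look roughly like the sketch below (an assumed workflow, not part of this commit):

    import os
    from modelscope.utils.regress_test_utils import MsRegressTool

    # Step 1: record. With REGRESSION_BASELINE set and baseline=True, the monitor dumps
    # <file_name>.bin and copies it into data/test/regression/ (tracked by git-lfs
    # via the .gitattributes rule added above).
    os.environ['REGRESSION_BASELINE'] = '1'
    record_tool = MsRegressTool(baseline=True)

    # Step 2: compare. Flip the test back to MsRegressTool(baseline=False) and commit the
    # refreshed .bin file; tests/run.py keeps REGRESSION_BASELINE=1 so CI runs the comparison.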