
[to #42322933] Fix bugs in the daily UT run

1. Fix bugs in the daily tests
2. Fix a bug where the learning rate was updated before the optimizer's first update
    TODO: this will still cause warnings when gradient accumulation (GA) is above 1
3. Remove the mode check from the text-classification preprocessor so it fits the base trainer (bug fix)
     Update some regression baseline bins to match the preprocessor change
4. Update the regression tool so that calling code can override atol and rtol
5. Add a default metric for the text-classification task
6. Remove the unused TF checkpoint conversion method from bert so that loading modeling_bert no longer requires TensorFlow
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10430764
master · yuze.zyz · 2 years ago · commit acba1786b0
13 changed files with 124 additions and 135 deletions:

  1. data/test/regression/sbert-base-tnews.bin (+3, -0)
  2. data/test/regression/sbert_nli.bin (+2, -2)
  3. data/test/regression/sbert_sen_sim.bin (+2, -2)
  4. modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py (+3, -1)
  5. modelscope/metrics/builder.py (+1, -0)
  6. modelscope/models/nlp/bert/modeling_bert.py (+0, -78)
  7. modelscope/preprocessors/nlp/nlp_base.py (+3, -4)
  8. modelscope/trainers/hooks/lr_scheduler_hook.py (+1, -1)
  9. modelscope/trainers/trainer.py (+1, -1)
  10. modelscope/utils/regress_test_utils.py (+58, -36)
  11. tests/msdatasets/test_ms_dataset.py (+2, -1)
  12. tests/trainers/test_finetune_sequence_classification.py (+31, -2)
  13. tests/trainers/test_trainer_with_nlp.py (+17, -7)

data/test/regression/sbert-base-tnews.bin (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0
size 151572

data/test/regression/sbert_nli.bin (+2, -2)

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62
-size 62231
+oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b
+size 61741

data/test/regression/sbert_sen_sim.bin (+2, -2)

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a
-size 62235
+oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41
+size 61745

modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py (+3, -1)

@@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter):

def generate_dummy_inputs(self,
shape: Tuple = None,
pair: bool = False,
**kwargs) -> Dict[str, Any]:
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.

@param shape: A tuple of input shape which should have at most two dimensions.
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor.
@param pair: Generate sentence pairs or single sentences for dummy inputs.
@return: Dummy inputs.
"""

@@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter):
**sequence_length
})
preprocessor: Preprocessor = build_preprocessor(cfg, field_name)
-if preprocessor.pair:
+if pair:
first_sequence = preprocessor.tokenizer.unk_token
second_sequence = preprocessor.tokenizer.unk_token
else:


modelscope/metrics/builder.py (+1, -0)

@@ -32,6 +32,7 @@ task_default_metrics = {
Tasks.sentiment_classification: [Metrics.seq_cls_metric],
Tasks.token_classification: [Metrics.token_cls_metric],
Tasks.text_generation: [Metrics.text_gen_metric],
Tasks.text_classification: [Metrics.seq_cls_metric],
Tasks.image_denoising: [Metrics.image_denoise_metric],
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric],
Tasks.image_portrait_enhancement:
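
Note on this addition (item 5 of the commit message): with the Tasks.text_classification entry in place, a text-classification trainer whose configuration lists no metrics can fall back to the sequence-classification metric. A minimal sketch of the lookup, assuming these names are importable exactly as they are defined in this file:

    from modelscope.metrics.builder import task_default_metrics
    from modelscope.utils.constant import Tasks

    # Fallback used when a trainer's cfg specifies no explicit metrics;
    # text classification now resolves to the sequence-classification metric.
    print(task_default_metrics[Tasks.text_classification])  # -> [Metrics.seq_cls_metric]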


modelscope/models/nlp/bert/modeling_bert.py (+0, -78)

@@ -15,7 +15,6 @@
"""PyTorch BERT model. """

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel,
find_pruneable_heads_and_indices,
prune_linear_layer)

from modelscope.models.base import TorchModel
from modelscope.utils.logger import get_logger
from .configuration_bert import BertConfig

@@ -50,81 +48,6 @@ logger = get_logger(__name__)
_CONFIG_FOR_DOC = 'BertConfig'


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re

import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
'https://www.tensorflow.org/install/ for installation instructions.'
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info(f'Loading TF weight {name} with shape {shape}')
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)

for name, array in zip(names, arrays):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in [
'adam_v', 'adam_m', 'AdamWeightDecayOptimizer',
'AdamWeightDecayOptimizer_1', 'global_step'
] for n in name):
logger.info(f"Skipping {'/'.join(name)}")
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
scope_names = re.split(r'_(\d+)', m_name)
else:
scope_names = [m_name]
if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif scope_names[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
elif scope_names[0] == 'squad':
pointer = getattr(pointer, 'classifier')
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info(f"Skipping {'/'.join(name)}")
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
if pointer.shape != array.shape:
raise ValueError(
f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
)
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info(f'Initialize PyTorch weight {name}')
pointer.data = torch.from_numpy(array)
return model


class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel):
"""

config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = 'bert'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r'position_ids']


modelscope/preprocessors/nlp/nlp_base.py (+3, -4)

@@ -2,7 +2,7 @@

import os.path as osp
import re
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import sentencepiece as spm
@@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
return isinstance(label, str) or isinstance(label, int)

if labels is not None:
-if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \
+if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \
and self.label2id is not None:
output[OutputKeys.LABELS] = [
self.label2id[str(label)] for label in labels
@@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
-kwargs['padding'] = kwargs.get(
-'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
+kwargs['padding'] = kwargs.get('padding', 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, mode=mode, **kwargs)
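
Note on the isinstance change above: a plain str is itself an Iterable, so a single string label would pass the old check and be iterated character by character; restricting the check to real sequences keeps single labels out of the list branch. A one-line illustration in plain Python, no modelscope needed:

    from collections.abc import Iterable

    isinstance('100', Iterable)             # True  -> a lone string label looked like a batch of labels
    isinstance('100', (tuple, list))        # False -> only genuine label lists are mapped element-wise
    isinstance(['0', '1'], (tuple, list))   # True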



modelscope/trainers/hooks/lr_scheduler_hook.py (+1, -1)

@@ -47,7 +47,7 @@ class LrSchedulerHook(Hook):
return lr

def before_train_iter(self, trainer):
-if not self.by_epoch:
+if not self.by_epoch and trainer.iter > 0:
if self.warmup_lr_scheduler is not None:
self.warmup_lr_scheduler.step()
else:
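
Note on this one-line fix (item 2 of the commit message): PyTorch expects optimizer.step() to be called before the first lr_scheduler.step(); because this hook steps the scheduler in before_train_iter, iteration 0 used to step the scheduler before any optimizer update, which triggered PyTorch's ordering warning and shifted the schedule by one step. The trainer.iter > 0 guard avoids that; as the TODO notes, with gradient accumulation above 1 the first real optimizer.step() still happens several iterations later, so the warning can still appear. A minimal plain-PyTorch sketch of the ordering, independent of modelscope:

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    for it in range(3):
        # Mirrors `if not self.by_epoch and trainer.iter > 0`: skip the scheduler on
        # iteration 0 so it never runs before the first optimizer.step().
        if it > 0:
            scheduler.step()
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()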


modelscope/trainers/trainer.py (+1, -1)

@@ -651,7 +651,7 @@ class EpochBasedTrainer(BaseTrainer):
# TODO: support MsDataset load for cv
if hasattr(data_cfg, 'name'):
dataset = MsDataset.load(
-dataset_name=data_cfg.name,
+dataset_name=data_cfg.pop('name'),
**data_cfg,
)
cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
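
Note on data_cfg.pop('name'): the dataset name is already passed explicitly as dataset_name, and the rest of data_cfg is forwarded with **, so popping the key first keeps 'name' from being forwarded a second time. A small stand-in sketch (plain dict and a dummy load(), not the real MsDataset API):

    def load(dataset_name, subset_name=None, split=None, **extra):
        # Stand-in for MsDataset.load; anything left in `extra` was forwarded redundantly.
        print(dataset_name, subset_name, split, extra)

    data_cfg = {'name': 'clue', 'subset_name': 'afqmc', 'split': 'train'}

    # Keyword arguments are evaluated left to right, so pop() removes 'name'
    # before **data_cfg is unpacked and the name is passed exactly once.
    load(dataset_name=data_cfg.pop('name'), **data_cfg)   # -> clue afqmc train {}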


modelscope/utils/regress_test_utils.py (+58, -36)

@@ -65,7 +65,8 @@ class RegressTool:
def monitor_module_single_forward(self,
module: nn.Module,
file_name: str,
-compare_fn=None):
+compare_fn=None,
+**kwargs):
"""Monitor a pytorch module in a single forward.

@param module: A torch module
@@ -107,7 +108,7 @@ class RegressTool:
baseline = os.path.join(tempfile.gettempdir(), name)
self.load(baseline, name)
with open(baseline, 'rb') as f:
-baseline_json = pickle.load(f)
+base = pickle.load(f)

class NumpyEncoder(json.JSONEncoder):
"""Special json encoder for numpy types
@@ -122,9 +123,9 @@ class RegressTool:
return obj.tolist()
return json.JSONEncoder.default(self, obj)

-print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}')
+print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}')
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}')
-if not compare_io_and_print(baseline_json, io_json, compare_fn):
+if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
raise ValueError('Result not match!')

@contextlib.contextmanager
@@ -136,7 +137,8 @@ class RegressTool:
ignore_keys=None,
compare_random=True,
reset_dropout=True,
-lazy_stop_callback=None):
+lazy_stop_callback=None,
+**kwargs):
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer.

This is usually useful when you try to change some dangerous code
@@ -265,14 +267,15 @@ class RegressTool:
baseline_json = pickle.load(f)

if level == 'strict' and not compare_io_and_print(
-baseline_json['forward'], io_json, compare_fn):
+baseline_json['forward'], io_json, compare_fn, **kwargs):
raise RuntimeError('Forward not match!')
if not compare_backward_and_print(
baseline_json['backward'],
bw_json,
compare_fn=compare_fn,
ignore_keys=ignore_keys,
-level=level):
+level=level,
+**kwargs):
raise RuntimeError('Backward not match!')
cfg_opt1 = {
'optimizer': baseline_json['optimizer'],
@@ -286,7 +289,8 @@ class RegressTool:
'cfg': summary['cfg'],
'state': None if not compare_random else summary['state']
}
-if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn):
+if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn,
+**kwargs):
raise RuntimeError('Cfg or optimizers not match!')


@@ -303,7 +307,8 @@ class MsRegressTool(RegressTool):
compare_fn=None,
ignore_keys=None,
compare_random=True,
-lazy_stop_callback=None):
+lazy_stop_callback=None,
+**kwargs):

if lazy_stop_callback is None:

@@ -319,7 +324,7 @@ class MsRegressTool(RegressTool):

trainer.register_hook(EarlyStopHook())

-def _train_loop(trainer, *args, **kwargs):
+def _train_loop(trainer, *args_train, **kwargs_train):
with self.monitor_module_train(
trainer,
file_name,
@@ -327,9 +332,11 @@ class MsRegressTool(RegressTool):
compare_fn=compare_fn,
ignore_keys=ignore_keys,
compare_random=compare_random,
-lazy_stop_callback=lazy_stop_callback):
+lazy_stop_callback=lazy_stop_callback,
+**kwargs):
try:
-return trainer.train_loop_origin(*args, **kwargs)
+return trainer.train_loop_origin(*args_train,
+**kwargs_train)
except MsRegressTool.EarlyStopError:
pass

@@ -530,7 +537,8 @@ def compare_arguments_nested(print_content,
)
return False
if not all([
-compare_arguments_nested(None, sub_arg1, sub_arg2)
+compare_arguments_nested(
+None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
for sub_arg1, sub_arg2 in zip(arg1, arg2)
]):
if print_content is not None:
@@ -551,7 +559,8 @@ def compare_arguments_nested(print_content,
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
return False
if not all([
-compare_arguments_nested(None, arg1[key], arg2[key])
+compare_arguments_nested(
+None, arg1[key], arg2[key], rtol=rtol, atol=atol)
for key in keys1
]):
if print_content is not None:
@@ -574,7 +583,7 @@ def compare_arguments_nested(print_content,
raise ValueError(f'type not supported: {type1}')


-def compare_io_and_print(baseline_json, io_json, compare_fn=None):
+def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
if compare_fn is None:

def compare_fn(*args, **kwargs):
@@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None):
else:
match = compare_arguments_nested(
f'unmatched module {key} input args', v1input['args'],
-v2input['args']) and match
+v2input['args'], **kwargs) and match
match = compare_arguments_nested(
f'unmatched module {key} input kwargs', v1input['kwargs'],
-v2input['kwargs']) and match
+v2input['kwargs'], **kwargs) and match
v1output = numpify_tensor_nested(v1['output'])
v2output = numpify_tensor_nested(v2['output'])
res = compare_fn(v1output, v2output, key, 'output')
@@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None):
)
match = match and res
else:
-match = compare_arguments_nested(f'unmatched module {key} outputs',
-v1output, v2output) and match
+match = compare_arguments_nested(
+f'unmatched module {key} outputs',
+arg1=v1output,
+arg2=v2output,
+**kwargs) and match
return match


@@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json,
bw_json,
level,
ignore_keys=None,
-compare_fn=None):
+compare_fn=None,
+**kwargs):
if compare_fn is None:

def compare_fn(*args, **kwargs):
@@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json,
data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
'grad'], bw_json[key]['data_after']
match = compare_arguments_nested(
-f'unmatched module {key} tensor data', data1, data2) and match
+f'unmatched module {key} tensor data',
+arg1=data1,
+arg2=data2,
+**kwargs) and match
if level == 'strict':
match = compare_arguments_nested(
-f'unmatched module {key} grad data', grad1,
-grad2) and match
+f'unmatched module {key} grad data',
+arg1=grad1,
+arg2=grad2,
+**kwargs) and match
match = compare_arguments_nested(
f'unmatched module {key} data after step', data_after1,
-data_after2) and match
+data_after2, **kwargs) and match
return match


-def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
+def compare_cfg_and_optimizers(baseline_json,
+cfg_json,
+compare_fn=None,
+**kwargs):
if compare_fn is None:

def compare_fn(*args, **kwargs):
@@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
print(
f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}"
)
-match = compare_arguments_nested('unmatched optimizer defaults',
-optimizer1['defaults'],
-optimizer2['defaults']) and match
-match = compare_arguments_nested('unmatched optimizer state_dict',
-optimizer1['state_dict'],
-optimizer2['state_dict']) and match
+match = compare_arguments_nested(
+'unmatched optimizer defaults', optimizer1['defaults'],
+optimizer2['defaults'], **kwargs) and match
+match = compare_arguments_nested(
+'unmatched optimizer state_dict', optimizer1['state_dict'],
+optimizer2['state_dict'], **kwargs) and match

res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
if res is not None:
@@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
print(
f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}"
)
-match = compare_arguments_nested('unmatched lr_scheduler state_dict',
-lr_scheduler1['state_dict'],
-lr_scheduler2['state_dict']) and match
+match = compare_arguments_nested(
+'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'],
+lr_scheduler2['state_dict'], **kwargs) and match

res = compare_fn(cfg1, cfg2, None, 'cfg')
if res is not None:
print(f'cfg compared with user compare_fn with result:{res}\n')
match = match and res
else:
-match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match
+match = compare_arguments_nested(
+'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match

res = compare_fn(state1, state2, None, 'state')
if res is not None:
@@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
match = match and res
else:
match = compare_arguments_nested('unmatched random state', state1,
-state2) and match
+state2, **kwargs) and match

return match
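
Note on the **kwargs threading above (item 4 of the commit message): the monitor_* entry points now forward extra keyword arguments all the way down to compare_arguments_nested, so calling code can relax rtol/atol for a particular baseline, e.g. by adding rtol=1e-3, atol=1e-4 to a monitor_ms_train(...) call next to compare_fn, as the updated tests do for the latter. A small sketch of the low-level effect; the example arrays and tolerances are made up:

    import numpy as np
    from modelscope.utils.regress_test_utils import compare_arguments_nested

    baseline = {'logits': np.array([0.1000, 0.2000])}
    latest = {'logits': np.array([0.1004, 0.2003])}

    # Tight tolerances report the drift (prints the message, returns False);
    # looser ones, now forwardable from the outer monitor_* calls, accept it.
    strict = compare_arguments_nested('logits drifted', baseline, latest, rtol=1e-5, atol=1e-6)
    loose = compare_arguments_nested('logits drifted', baseline, latest, rtol=1e-2, atol=1e-3)
    print(strict, loose)   # False True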

tests/msdatasets/test_ms_dataset.py (+2, -1)

@@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ms_csv_basic(self):
ms_ds_train = MsDataset.load(
-'afqmc_small', namespace='userxiaoming', split='train')
+'clue', subset_name='afqmc',
+split='train').to_hf_dataset().select(range(5))
print(next(iter(ms_ds_train)))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


tests/trainers/test_finetune_sequence_classification.py (+31, -2)

@@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
calculate_fisher
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.data_utils import to_device
-from modelscope.utils.regress_test_utils import MsRegressTool
+from modelscope.utils.regress_test_utils import (MsRegressTool,
+compare_arguments_nested)
from modelscope.utils.test_utils import test_level


@@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
def test_trainer_repeatable(self):
import torch # noqa

def compare_fn(value1, value2, key, type):
# Ignore the differences between optimizers of two torch versions
if type != 'optimizer':
return None

match = (value1['type'] == value2['type'])
shared_defaults = set(value1['defaults'].keys()).intersection(
set(value2['defaults'].keys()))
match = all([
compare_arguments_nested(f'Optimizer defaults {key} not match',
value1['defaults'][key],
value2['defaults'][key])
for key in shared_defaults
]) and match
match = (len(value1['state_dict']['param_groups']) == len(
value2['state_dict']['param_groups'])) and match
for group1, group2 in zip(value1['state_dict']['param_groups'],
value2['state_dict']['param_groups']):
shared_keys = set(group1.keys()).intersection(
set(group2.keys()))
match = all([
compare_arguments_nested(
f'Optimizer param_groups {key} not match', group1[key],
group2[key]) for key in shared_keys
]) and match
return match

def cfg_modify_fn(cfg):
cfg.task = 'nli'
cfg['preprocessor'] = {'type': 'nli-tokenizer'}
@@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
name=Trainers.nlp_base_trainer, default_args=kwargs)

with self.regress_tool.monitor_ms_train(
-trainer, 'sbert-base-tnews', level='strict'):
+trainer, 'sbert-base-tnews', level='strict',
+compare_fn=compare_fn):
trainer.train()

def finetune(self,


tests/trainers/test_trainer_with_nlp.py (+17, -7)

@@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase):
os.makedirs(self.tmp_dir)

self.dataset = MsDataset.load(
-'afqmc_small', namespace='userxiaoming', split='train')
+'clue', subset_name='afqmc',
+split='train').to_hf_dataset().select(range(2))

def tearDown(self):
shutil.rmtree(self.tmp_dir)
@@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase):
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
pipeline_sentence_similarity(output_dir)

-@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
def test_trainer_with_backbone_head(self):
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
kwargs = dict(
@@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase):
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
cfg = read_config(model_id, revision='beta')
cfg.train.max_epochs = 20
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
cfg.train.work_dir = self.tmp_dir
cfg_file = os.path.join(self.tmp_dir, 'config.json')
cfg.dump(cfg_file)
@@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase):
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
self.assertTrue(Metrics.accuracy in eval_results)

-@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_trainer_with_configured_datasets(self):
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
cfg: Config = read_config(model_id)
cfg.train.max_epochs = 20
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
cfg.train.work_dir = self.tmp_dir
cfg.dataset = {
'train': {
-'name': 'afqmc_small',
+'name': 'clue',
+'subset_name': 'afqmc',
'split': 'train',
-'namespace': 'userxiaoming'
},
'val': {
-'name': 'afqmc_small',
+'name': 'clue',
+'subset_name': 'afqmc',
'split': 'train',
-'namespace': 'userxiaoming'
},
}
cfg_file = os.path.join(self.tmp_dir, 'config.json')
@@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase):
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
cfg: Config = read_config(model_id)
cfg.train.max_epochs = 3
cfg.preprocessor.first_sequence = 'sentence1'
cfg.preprocessor.second_sequence = 'sentence2'
cfg.preprocessor.label = 'label'
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
cfg.train.work_dir = self.tmp_dir
cfg_file = os.path.join(self.tmp_dir, 'config.json')
cfg.dump(cfg_file)

