
[to #42322933] Fix some bugs

1. Add an F1 score to the sequence classification metric
2. Fix a bug where the trainer's evaluate method does not support a plain pytorch_model.bin
3. Fix a bug in the evaluation of the VECO trainer
4. Add tips for when the lr_scheduler in the trainer requires a newer torch version
5. Add some comments
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10532230

Branch: master
Authors: yuze.zyz, yingda.chen (2 years ago)
Commit: 212cf53318
11 changed files with 153 additions and 18 deletions:

1. modelscope/metrics/sequence_classification_metric.py (+8, -1)
2. modelscope/models/base/base_model.py (+21, -1)
3. modelscope/models/science/unifold/modules/structure_module.py (+2, -2)
4. modelscope/preprocessors/base.py (+43, -1)
5. modelscope/trainers/hooks/checkpoint_hook.py (+5, -4)
6. modelscope/trainers/nlp_trainer.py (+4, -2)
7. modelscope/trainers/trainer.py (+17, -2)
8. modelscope/utils/checkpoint.py (+1, -3)
9. tests/metrics/test_text_classification_metrics.py (+32, -0)
10. tests/trainers/test_finetune_sequence_classification.py (+1, -1)
11. tests/trainers/test_trainer_with_nlp.py (+19, -1)

modelscope/metrics/sequence_classification_metric.py (+8, -1)

@@ -3,6 +3,7 @@
 from typing import Dict
 
 import numpy as np
+from sklearn.metrics import accuracy_score, f1_score
 
 from modelscope.metainfo import Metrics
 from modelscope.outputs import OutputKeys
@@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric):
         preds = np.argmax(preds, axis=1)
         return {
             MetricKeys.ACCURACY:
-            (preds == labels).astype(np.float32).mean().item()
+            accuracy_score(labels, preds),
+            MetricKeys.F1:
+            f1_score(
+                labels,
+                preds,
+                average='micro' if any([label > 1
+                                        for label in labels]) else None),
         }
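
As context for the `average` switch above: micro-averaged F1 pools true/false positives across all classes, and for single-label multi-class data it reduces to plain accuracy. A minimal standalone sketch of that fact (assuming scikit-learn is installed; the arrays are illustrative):

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

labels = np.array([0, 1, 2, 0])
preds = np.array([0, 2, 2, 0])  # 3 of 4 predictions correct

# Micro averaging pools TP/FP/FN over all classes; in single-label
# multi-class classification every error counts as one FP and one FN,
# so micro-F1 equals accuracy (0.75 here).
assert accuracy_score(labels, preds) == f1_score(labels, preds, average='micro')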

modelscope/models/base/base_model.py (+21, -1)

@@ -67,8 +67,28 @@ class Model(ABC):
                         cfg_dict: Config = None,
                         device: str = None,
                         **kwargs):
-        """ Instantiate a model from local directory or remote model repo. Note
+        """Instantiate a model from a local directory or a remote model repo. Note
         that when loading from remote, the model revision can be specified.
+
+        Args:
+            model_name_or_path(str): A model dir or a model id to be loaded.
+            revision(str, `optional`): The revision used when the model_name_or_path is
+                a model id of the remote hub, default `master`.
+            cfg_dict(Config, `optional`): An optional model config. If provided, it will replace
+                the config read out of the `model_name_or_path`.
+            device(str, `optional`): The device on which to load the model.
+            **kwargs:
+                task(str, `optional`): The `Tasks` enumeration value to replace the task value
+                    read out of the config in the `model_name_or_path`. This is useful when the model
+                    to be loaded is not the same as the model that was saved,
+                    for example, loading a `backbone` into a `text-classification` model.
+                Other kwargs will be fed directly into the `model` key to replace the default configs.
+        Returns:
+            A model instance.
+
+        Examples:
+            >>> from modelscope.models import Model
+            >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification')
         """
         prefetched = kwargs.get('model_prefetched')
         if prefetched is not None:


modelscope/models/science/unifold/modules/structure_module.py (+2, -2)

@@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module):
         pt_att *= pt_att
 
         pt_att = pt_att.sum(dim=-1)
-        head_weights = self.softplus(self.head_weights).view(
-            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))
+        head_weights = self.softplus(self.head_weights).view(  # noqa
+            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))  # noqa
         head_weights = head_weights * math.sqrt(
             1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
         pt_att *= head_weights * (-0.5)


modelscope/preprocessors/base.py (+43, -1)

@@ -147,8 +147,50 @@ class Preprocessor(ABC):
                         cfg_dict: Config = None,
                         preprocessor_mode=ModeKeys.INFERENCE,
                         **kwargs):
-        """ Instantiate a model from local directory or remote model repo. Note
+        """Instantiate a preprocessor from a local directory or a remote model repo. Note
         that when loading from remote, the model revision can be specified.
+
+        Args:
+            model_name_or_path(str): A model dir or a model id from which to load the preprocessor.
+            revision(str, `optional`): The revision used when the model_name_or_path is
+                a model id of the remote hub, default `master`.
+            cfg_dict(Config, `optional`): An optional config. If provided, it will replace
+                the config read out of the `model_name_or_path`.
+            preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor; can be `train`, `eval`,
+                or `inference`. Default value `inference`.
+                The preprocessor field in the config may contain two sub-preprocessors:
+                >>> {
+                >>>     "train": {
+                >>>         "type": "some-train-preprocessor"
+                >>>     },
+                >>>     "val": {
+                >>>         "type": "some-eval-preprocessor"
+                >>>     }
+                >>> }
+                In this scenario, the `train` preprocessor will be loaded in the `train` mode, and the `val`
+                preprocessor will be loaded in the `eval` or `inference` mode. The `mode` field in the
+                preprocessor class will be assigned in all the modes.
+                Or just one:
+                >>> {
+                >>>     "type": "some-train-preprocessor"
+                >>> }
+                In this scenario, the sole preprocessor will be loaded in all the modes,
+                and the `mode` field in the preprocessor class will be assigned.
+
+            **kwargs:
+                task(str, `optional`): The `Tasks` enumeration value to replace the task value
+                    read out of the config in the `model_name_or_path`.
+                    This is useful when the preprocessor does not have a `type` field and the task to be
+                    used is not the same as the task with which the model was saved.
+                Other kwargs will be fed directly into the preprocessor to replace the default configs.
+
+        Returns:
+            The preprocessor instance.
+
+        Examples:
+            >>> from modelscope.preprocessors import Preprocessor
+            >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base')
+
         """
         if not os.path.exists(model_name_or_path):
             model_dir = snapshot_download(
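
A small illustration of the mode-selection rule described in the docstring above. The `pick_sub_cfg` helper and the dict literal are hypothetical stand-ins for the parsed `preprocessor` field of a configuration.json, not modelscope API:

def pick_sub_cfg(preprocessor_cfg: dict, mode: str) -> dict:
    # With separate train/val entries, 'train' serves the train mode and
    # 'val' serves both eval and inference; a single entry serves all modes.
    if 'train' in preprocessor_cfg and 'val' in preprocessor_cfg:
        return preprocessor_cfg['train' if mode == 'train' else 'val']
    return preprocessor_cfg

cfg = {
    'train': {'type': 'some-train-preprocessor'},
    'val': {'type': 'some-eval-preprocessor'},
}
assert pick_sub_cfg(cfg, 'train')['type'] == 'some-train-preprocessor'
assert pick_sub_cfg(cfg, 'inference')['type'] == 'some-eval-preprocessor'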


modelscope/trainers/hooks/checkpoint_hook.py (+5, -4)

@@ -101,8 +101,9 @@ class CheckpointHook(Hook):
             model = trainer.model.module
         else:
             model = trainer.model
-        meta = load_checkpoint(filename, model, trainer.optimizer,
-                               trainer.lr_scheduler)
+        meta = load_checkpoint(filename, model,
+                               getattr(trainer, 'optimizer', None),
+                               getattr(trainer, 'lr_scheduler', None))
         trainer._epoch = meta.get('epoch', trainer._epoch)
         trainer._iter = meta.get('iter', trainer._iter)
         trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter)
@@ -111,7 +112,7 @@ class CheckpointHook(Hook):
             # hook: Hook
             key = f'{hook.__class__}-{i}'
             if key in meta and hasattr(hook, 'load_state_dict'):
-                hook.load_state_dict(meta[key])
+                hook.load_state_dict(meta.get(key, {}))
             else:
                 trainer.logger.warn(
                     f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
@@ -123,7 +124,7 @@ class CheckpointHook(Hook):
                 f'The modelscope version of loaded checkpoint does not match the runtime version. '
                 f'The saved version: {version}, runtime version: {__version__}'
             )
-        trainer.logger.warn(
+        trainer.logger.info(
             f'Checkpoint {filename} saving time: {meta.get("time")}')
         return meta
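
The switch from direct attribute access to `getattr(..., None)` matters for trainers that never built an optimizer or lr_scheduler, such as an evaluation-only run. A minimal sketch of the difference (the stub class is hypothetical):

class EvalOnlyTrainer:
    # Stand-in for a trainer that was constructed without an optimizer.
    pass

trainer = EvalOnlyTrainer()

# trainer.optimizer would raise AttributeError here; getattr degrades to
# None, letting load_checkpoint skip optimizer/lr_scheduler restoration.
assert getattr(trainer, 'optimizer', None) is None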




modelscope/trainers/nlp_trainer.py (+4, -2)

@@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer):
                 break
 
         for metric_name in self.metrics:
-            metric_values[metric_name] = np.average(
-                [m[metric_name] for m in metric_values.values()])
+            all_metrics = [m[metric_name] for m in metric_values.values()]
+            for key in all_metrics[0].keys():
+                metric_values[key] = np.average(
+                    [metric[key] for metric in all_metrics])
 
         return metric_values
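
For context: the VECO evaluation loop collects one metric dict per evaluated dataset, and each dict now holds several named values (e.g. accuracy and F1), so the old code's attempt to average the dicts themselves no longer works. A standalone sketch of the corrected aggregation (the numbers are illustrative):

import numpy as np

# One metric dict per dataset, as collected by the evaluation loop.
all_metrics = [{'accuracy': 0.75, 'f1': 0.5},
               {'accuracy': 0.25, 'f1': 1.0}]

# Average each named metric across datasets.
averaged = {key: np.average([m[key] for m in all_metrics])
            for key in all_metrics[0]}
assert averaged == {'accuracy': 0.5, 'f1': 0.75}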

modelscope/trainers/trainer.py (+17, -2)

@@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer):
         return dataset
 
     def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
-        return build_optimizer(self.model, cfg=cfg, default_args=default_args)
+        try:
+            return build_optimizer(
+                self.model, cfg=cfg, default_args=default_args)
+        except KeyError as e:
+            self.logger.error(
+                f'Error building optimizer: the optimizer {cfg} is a native torch optimizer; '
+                f'please check whether your torch version {torch.__version__} matches the config.'
+            )
+            raise e
 
     def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
-        return build_lr_scheduler(cfg=cfg, default_args=default_args)
+        try:
+            return build_lr_scheduler(cfg=cfg, default_args=default_args)
+        except KeyError as e:
+            self.logger.error(
+                f'Error building lr_scheduler: the lr_scheduler {cfg} is a native torch lr_scheduler; '
+                f'please check whether your torch version {torch.__version__} matches the config.'
+            )
+            raise e
 
     def create_optimizer_and_scheduler(self):
         """ Create optimizer and lr scheduler

modelscope/utils/checkpoint.py (+1, -3)

@@ -134,9 +134,7 @@ def load_checkpoint(filename,
     state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
         'state_dict']
     model.load_state_dict(state_dict)
-
-    if 'meta' in checkpoint:
-        return checkpoint.get('meta', {})
+    return checkpoint.get('meta', {})
 
 
 def save_pretrained(model,
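
With the early `if 'meta' in checkpoint` branch removed, `load_checkpoint` returns `{}` instead of falling through to `None` when the file is a bare state_dict such as pytorch_model.bin, which is what lets the trainer's evaluate method accept such files. A minimal illustration of the two checkpoint layouts (placeholder contents):

from collections import OrderedDict

plain_ckpt = OrderedDict()  # the shape of a bare pytorch_model.bin
trainer_ckpt = {'state_dict': plain_ckpt, 'meta': {'epoch': 3}}

# Both layouts now yield a dict, never None, so callers such as
# CheckpointHook can call meta.get(...) unconditionally.
assert trainer_ckpt.get('meta', {}) == {'epoch': 3}
assert plain_ckpt.get('meta', {}) == {}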


tests/metrics/test_text_classification_metrics.py (+32, -0)

@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+import numpy as np
+
+from modelscope.metrics.sequence_classification_metric import \
+    SequenceClassificationMetric
+from modelscope.utils.test_utils import test_level
+
+
+class TestTextClsMetrics(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_value(self):
+        metric = SequenceClassificationMetric()
+        outputs = {
+            'logits':
+            np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0],
+                      [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7],
+                      [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]])
+        }
+        inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])}
+        metric.add(outputs, inputs)
+        ret = metric.evaluate()
+        self.assertTrue(np.isclose(ret['f1'], 0.5))
+        self.assertTrue(np.isclose(ret['accuracy'], 0.5))
+        print(ret)
+
+
+if __name__ == '__main__':
+    unittest.main()
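
Sanity check on the asserted values: the argmax of each logits row gives predictions [0, 1, 2, 2, 2, 0, 0, 0] against labels [0, 1, 2, 2, 0, 1, 2, 2], i.e. 4 of 8 correct, hence accuracy 0.5; and because the labels contain values greater than 1, the metric uses micro-averaged F1, which equals accuracy for single-label multi-class data, so f1 is 0.5 as well.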

tests/trainers/test_finetune_sequence_classification.py (+1, -1)

@@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         train_datasets = []
         from datasets import DownloadConfig
         dc = DownloadConfig()
-        dc.local_files_only = True
+        dc.local_files_only = False
         for lang in langs:
             train_datasets.append(
                 load_dataset('xnli', lang, split='train', download_config=dc))


tests/trainers/test_trainer_with_nlp.py (+19, -1)

@@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase):
                 trainer, 'trainer_continue_train', level='strict'):
             trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth'))
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_with_evaluation(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
+        cache_path = snapshot_download(model_id)
+        model = SbertForSequenceClassification.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        print(trainer.evaluate(cache_path + '/pytorch_model.bin'))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(tmp_dir):
             os.makedirs(tmp_dir)
 
-        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
         cache_path = snapshot_download(model_id)
         model = SbertForSequenceClassification.from_pretrained(cache_path)
         kwargs = dict(

