
Merge branch 'master-gitlab' into merge_master_internal_1026

Branch: master
Author: wenmeng.zwm, 2 years ago
Commit: 7e8e60f3b3
29 changed files with 540 additions and 177 deletions
  1. +8 -0    modelscope/hub/api.py
  2. +1 -0    modelscope/metainfo.py
  3. +14 -8   modelscope/metrics/accuracy_metric.py
  4. +87 -0   modelscope/metrics/ned_metric.py
  5. +14 -3   modelscope/metrics/text_generation_metric.py
  6. +2 -1    modelscope/models/nlp/__init__.py
  7. +19 -0   modelscope/models/nlp/bloom/__init__.py
  8. +2 -2    modelscope/models/nlp/bloom/backbone.py
  9. +2 -0    modelscope/models/nlp/gpt3/backbone.py
  10. +2 -6   modelscope/models/nlp/gpt3/text_generation.py
  11. +2 -2   modelscope/models/nlp/palm_v2/backbone.py
  12. +3 -38  modelscope/models/nlp/palm_v2/text_generation.py
  13. +12 -0  modelscope/msdatasets/ms_dataset.py
  14. +0 -4   modelscope/pipelines/nlp/faq_question_answering_pipeline.py
  15. +38 -7  modelscope/pipelines/nlp/text_generation_pipeline.py
  16. +143 -14 modelscope/preprocessors/base.py
  17. +20 -2  modelscope/preprocessors/ofa/ocr_recognition.py
  18. +1 -3   modelscope/trainers/multi_modal/ofa/ofa_trainer.py
  19. +15 -11 modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
  20. +3 -1   modelscope/trainers/nlp/__init__.py
  21. +36 -0  modelscope/trainers/nlp/text_generation_trainer.py
  22. +24 -2  modelscope/trainers/trainer.py
  23. +8 -28  modelscope/trainers/utils/inference.py
  24. +13 -2  tests/hub/test_hub_upload.py
  25. +12 -0  tests/pipelines/test_text_generation.py
  26. +10 -11 tests/trainers/test_finetune_mplug.py
  27. +12 -6  tests/trainers/test_finetune_text_generation.py
  28. +27 -25 tests/trainers/test_ofa_trainer.py
  29. +10 -1  tests/trainers/utils/test_inference.py

+8 -0  modelscope/hub/api.py

@@ -266,6 +266,14 @@ class HubApi:
logger.info('Create new branch %s' % revision)
git_wrapper.new_branch(tmp_dir, revision)
git_wrapper.checkout(tmp_dir, revision)
files_in_repo = os.listdir(tmp_dir)
for f in files_in_repo:
if f[0] != '.':
src = os.path.join(tmp_dir, f)
if os.path.isfile(src):
os.remove(src)
else:
shutil.rmtree(src, ignore_errors=True)
for f in files_to_save:
if f[0] != '.':
src = os.path.join(model_dir, f)


+1 -0  modelscope/metainfo.py

@@ -313,6 +313,7 @@ class Trainers(object):
nlp_base_trainer = 'nlp-base-trainer'
nlp_veco_trainer = 'nlp-veco-trainer'
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
text_generation_trainer = 'text-generation-trainer'

# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'


+14 -8  modelscope/metrics/accuracy_metric.py

@@ -27,15 +27,21 @@ class AccuracyMetric(Metric):
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
ground_truths = inputs[label_name]
eval_results = outputs[label_name]
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in outputs and outputs[key] is not None:
eval_results = outputs[key]
break
assert type(ground_truths) == type(eval_results)
if isinstance(ground_truths, list):
self.preds.extend(eval_results)
self.labels.extend(ground_truths)
elif isinstance(ground_truths, np.ndarray):
self.preds.extend(eval_results.tolist())
self.labels.extend(ground_truths.tolist())
else:
raise 'only support list or np.ndarray'
for truth in ground_truths:
self.labels.append(truth)
for result in eval_results:
if isinstance(truth, str):
self.preds.append(result.strip().replace(' ', ''))
else:
self.preds.append(result)

def evaluate(self):
assert len(self.preds) == len(self.labels)


+87 -0  modelscope/metrics/ned_metric.py

@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import numpy as np

from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
class NedMetric(Metric):
"""The ned metric computation class for classification classes.

This metric class calculates the levenshtein distance between sentences for the whole input batches.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.preds = []
self.labels = []

def add(self, outputs: Dict, inputs: Dict):
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
ground_truths = inputs[label_name]
eval_results = outputs[label_name]
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in outputs and outputs[key] is not None:
eval_results = outputs[key]
break
assert type(ground_truths) == type(eval_results)
if isinstance(ground_truths, list):
self.preds.extend(eval_results)
self.labels.extend(ground_truths)
elif isinstance(ground_truths, np.ndarray):
self.preds.extend(eval_results.tolist())
self.labels.extend(ground_truths.tolist())
else:
raise Exception('only support list or np.ndarray')

def evaluate(self):
assert len(self.preds) == len(self.labels)
return {
MetricKeys.NED: (np.asarray([
1.0 - NedMetric._distance(pred, ref)
for pred, ref in zip(self.preds, self.labels)
])).mean().item()
}

@staticmethod
def _distance(pred, ref):
if pred is None or ref is None:
raise TypeError('Argument (pred or ref) is NoneType.')
if pred == ref:
return 0.0
if len(pred) == 0:
return len(ref)
if len(ref) == 0:
return len(pred)
m_len = max(len(pred), len(ref))
if m_len == 0:
return 0.0

def levenshtein(s0, s1):
v0 = [0] * (len(s1) + 1)
v1 = [0] * (len(s1) + 1)

for i in range(len(v0)):
v0[i] = i

for i in range(len(s0)):
v1[0] = i + 1
for j in range(len(s1)):
cost = 1
if s0[i] == s1[j]:
cost = 0
v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
v0, v1 = v1, v0
return v0[len(s1)]

return levenshtein(pred, ref) / m_len
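
For context, a minimal standalone sketch of the normalized edit distance that NedMetric reports (illustrative only, not part of this commit's diff; the metric returns 1 - distance, so higher is better):

# Sketch: NED score = 1 - levenshtein(pred, ref) / max(len(pred), len(ref))
def levenshtein(s0: str, s1: str) -> int:
    prev = list(range(len(s1) + 1))
    for i, c0 in enumerate(s0):
        cur = [i + 1]
        for j, c1 in enumerate(s1):
            cur.append(min(cur[j] + 1,             # insertion
                           prev[j + 1] + 1,        # deletion
                           prev[j] + (c0 != c1)))  # substitution
        prev = cur
    return prev[-1]

def ned_score(pred: str, ref: str) -> float:
    if not pred and not ref:
        return 1.0
    return 1.0 - levenshtein(pred, ref) / max(len(pred), len(ref))

print(ned_score('kitten', 'sitting'))  # 1 - 3/7 ≈ 0.571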

+14 -3  modelscope/metrics/text_generation_metric.py

@@ -36,20 +36,31 @@ class TextGenerationMetric(Metric):
for char in string
]).split())

def add(self, outputs: Dict[str, List[str]], inputs: Dict = None):
ground_truths = outputs['tgts']
def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
ground_truths = inputs['tgts']
eval_results = outputs['preds']
for truth in ground_truths:
self.tgts.append(self.rebuild_str(truth))
for result in eval_results:
self.preds.append(self.rebuild_str(result))

def _check(self, pred: str, tgt: str) -> bool:

def remove_useless(string: str) -> str:
return string.replace(' ', '').replace('.', '')

return remove_useless(pred) and remove_useless(tgt)

def evaluate(self):
assert self.preds, 'preds in TextGenerationMetric must not be empty!'
tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts)
if self._check(pred, tgt)]
preds, tgts = zip(*tmp)

def mean(iter: Iterable) -> float:
return sum(iter) / len(self.preds)

rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts)
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
pred_split = tuple(pred.split(' ') for pred in self.preds)


+2 -1  modelscope/models/nlp/__init__.py

@@ -49,7 +49,7 @@ if TYPE_CHECKING:
VecoForSequenceClassification,
VecoForTokenClassification, VecoModel, VecoTokenizer,
VecoTokenizerFast)
from .bloom import BloomModel
else:
_import_structure = {
'backbones': ['SbertModel'],
@@ -107,6 +107,7 @@ else:
'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'],
'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'],
}

import sys


+19 -0  modelscope/models/nlp/bloom/__init__.py

@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .backbone import BloomModel
else:
_import_structure = {
'backbone': ['BloomModel'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+2 -2  modelscope/models/nlp/bloom/backbone.py

@@ -4,10 +4,10 @@ from transformers import BloomModel as BloomModelTransform

from modelscope.metainfo import Models
from modelscope.models.builder import BACKBONES
from modelscope.utils.constant import Fields
from modelscope.utils.constant import Tasks


@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom)
@BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.bloom)
class BloomModel(BloomModelTransform):

def __init__(self, **kwargs):


+2 -0  modelscope/models/nlp/gpt3/backbone.py

@@ -342,6 +342,8 @@ class GPT3Model(PreTrainedModel):
state_dict_file = os.path.join(pretrained_model_name_or_path,
ModelFile.TORCH_MODEL_BIN_FILE)
state_dict = torch.load(state_dict_file)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
state_dict = {
k.replace('model.language_model', 'language_model'): v
for k, v in state_dict.items()


+2 -6  modelscope/models/nlp/gpt3/text_generation.py

@@ -42,7 +42,7 @@ class GPT3ForTextGeneration(TorchModel):
"""
return self.model(**input)

def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
assert 'input_ids' in input, "generate function must accept 'input_ids' key"
input_ids = input['input_ids']
if 'attention_mask' in input:
@@ -59,8 +59,4 @@ class GPT3ForTextGeneration(TorchModel):
gen_params['top_k'] = input.pop('top_k', 10)
gen_params['top_p'] = input.pop('top_p', None)
sample_output = self.model.generate(**gen_params)
return {
OutputKeys.TEXT:
self.tokenizer.decode(sample_output[0],
skip_special_tokens=True).replace(' ', '')
}
return {'sequences': sample_output[0]}

+2 -2  modelscope/models/nlp/palm_v2/backbone.py

@@ -1314,8 +1314,8 @@ class Translator(object):

return results

def __call__(self, input_ids: torch.Tensor,
attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
**kwargs) -> Dict[str, torch.Tensor]:
batch = self.Batch(
batch_size=input_ids.size()[0],
src=input_ids,


+3 -38  modelscope/models/nlp/palm_v2/text_generation.py

@@ -29,22 +29,6 @@ class PalmForTextGeneration(TorchModel):
self.tokenizer = self.model.tokenizer
self.generator = Translator(self.model)

def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]',
''),
(r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
('[CLS]', ''), ('[UNK]', ''), (' ', ''))
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
('<pad>', ''), ('<s>', ''), ('</s>', ''),
('<unk>', ' '), ('<q>', '. '))

replace_tokens = replace_tokens_roberta \
if self.model.config.encoder == 'roberta' else replace_tokens_bert
strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
for _old, _new in replace_tokens:
strings = [s.replace(_old, _new) for s in strings]
return strings

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

@@ -57,29 +41,10 @@ class PalmForTextGeneration(TorchModel):
{
'loss': Tensor([12.34]), # loss for backward
}
or
{
'preds': List["hello word"...] # the predicted strings
'tgts': List["hello world"...] # target strings
}
"""
if self.training:
return self.model(**input)
else:
outputs = self.generator(input['input_ids'],
input['attention_mask'])
preds = outputs['predictions']
pred_ids_list = [
pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
]
tgt_ids_list = input['labels'].cpu().numpy().tolist()
return {
'preds': self._evaluate_postprocess(pred_ids_list),
'tgts': self._evaluate_postprocess(tgt_ids_list)
}
return self.model(**input)

def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
outputs = self.generator(**input)
preds = outputs['predictions']
pred_ids_list = [preds[0][0].cpu().numpy().tolist()]
return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]}
return {'sequences': [pred[0] for pred in preds]}

+12 -0  modelscope/msdatasets/ms_dataset.py

@@ -563,6 +563,18 @@ class MsDataset:
self._hf_ds.reset_format()
return self._hf_ds

def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset:
"""
Rename columns and return the underlying hf dataset directly
TODO: support native MsDataset column rename.
Args:
column_mapping: the mapping of the original and new column names
Returns:
underlying hf dataset
"""
self._hf_ds.reset_format()
return self._hf_ds.rename_columns(column_mapping)

@staticmethod
def upload(object_name: str,
local_file_path: str,
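
The new remap_columns helper is what the updated tests below switch to; a short usage sketch based on those tests (dataset and column names come from tests/trainers/test_finetune_text_generation.py):

from modelscope.msdatasets import MsDataset

# remap_columns renames the raw columns and returns the underlying hf dataset.
dataset_dict = MsDataset.load('DuReader_robust-QG')
train_dataset = dataset_dict['train'].remap_columns({
    'text1': 'src_txt',
    'text2': 'tgt_txt'
})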


+0 -4  modelscope/pipelines/nlp/faq_question_answering_pipeline.py

@@ -26,10 +26,6 @@ class FaqQuestionAnsweringPipeline(Pipeline):
if preprocessor is None:
preprocessor = Preprocessor.from_pretrained(
model.model_dir, **kwargs)
if preprocessor is None:
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
preprocessor = FaqQuestionAnsweringPreprocessor(
model.model_dir, **kwargs)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def _sanitize_parameters(self, **pipeline_parameters):


+38 -7  modelscope/pipelines/nlp/text_generation_pipeline.py

@@ -53,7 +53,7 @@ class TextGenerationPipeline(Pipeline):
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
cfg = read_config(model.model_dir)
self.postprocessor = cfg.pop('postprocessor', None)
self.postprocessor = cfg.pop('postprocessor', 'decode')
if preprocessor is None:
preprocessor_cfg = cfg.preprocessor
preprocessor_cfg.update({
@@ -78,8 +78,37 @@ class TextGenerationPipeline(Pipeline):
with torch.no_grad():
return self.model.generate(inputs, **forward_params)

def sentence_piece(self, inputs) -> Dict[str, Tensor]:
return self.preprocessor.tokenizer.decode(inputs.tolist()[0])
def _is_chinese_char(self, word: str):
chinese_punctuations = (',', '。', ';', ':', '!', '?', '《', '》')
return len(word) == 1 \
and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)

def _remove_space_between_chinese_chars(self, decoded: str):
old_word_list = decoded.split(' ')
new_word_list = []
start = -1
for i, word in enumerate(old_word_list):
if self._is_chinese_char(word):
if start == -1:
start = i
else:
if start != -1:
new_word_list.append(''.join(old_word_list[start:i]))
start = -1
new_word_list.append(word)
if start != -1:
new_word_list.append(''.join(old_word_list[start:]))
return ' '.join(new_word_list)

def decode(self, inputs) -> str:
tokenizer = self.preprocessor.tokenizer
return tokenizer.decode(inputs.tolist(), skip_special_tokens=True)

def roberta(self, inputs) -> str:
tokenizer = self.preprocessor.tokenizer
decoded = tokenizer.decode(inputs.tolist())
return decoded.replace('<q>', '. ').replace('<mask>',
'. ').replace('</s>', '')

def postprocess(self, inputs: Dict[str, Tensor],
**postprocess_params) -> Dict[str, str]:
@@ -91,7 +120,9 @@ class TextGenerationPipeline(Pipeline):
Returns:
Dict[str, str]: the prediction results
"""
return inputs if self.postprocessor is None else {
OutputKeys.TEXT:
getattr(self, self.postprocessor.replace('-', '_'))(inputs)
}
inputs = inputs['sequences']
if isinstance(inputs, list):
inputs = inputs[0]
decoded = getattr(self, self.postprocessor)(inputs)
text = self._remove_space_between_chinese_chars(decoded)
return {OutputKeys.TEXT: text}
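
A condensed sketch of the Chinese-character merging that postprocess now applies (illustrative only; it omits the Chinese punctuation handling in _is_chinese_char above):

def merge_chinese(decoded: str) -> str:
    def is_cn(word: str) -> bool:
        return len(word) == 1 and '\u4e00' <= word <= '\u9fa5'

    out, run = [], []
    for word in decoded.split(' '):
        if is_cn(word):
            run.append(word)              # collect consecutive single Chinese chars
        else:
            if run:
                out.append(''.join(run))  # join a run without spaces
                run = []
            out.append(word)
    if run:
        out.append(''.join(run))
    return ' '.join(out)

print(merge_chinese('今 天 天 气 不 错 good'))  # -> '今天天气不错 good'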

+143 -14  modelscope/preprocessors/base.py

@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence

from modelscope.utils.config import Config
from modelscope.metainfo import Models, Preprocessors
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.logger import get_logger
@@ -12,6 +13,112 @@ from .builder import build_preprocessor

logger = get_logger(__name__)

PREPROCESSOR_MAP = {
# nlp
# bart
(Models.bart, Tasks.text_error_correction):
Preprocessors.text_error_correction,

# bert
(Models.bert, Tasks.backbone):
Preprocessors.sen_cls_tokenizer,
(Models.bert, Tasks.document_segmentation):
Preprocessors.document_segmentation,
(Models.bert, Tasks.fill_mask):
Preprocessors.fill_mask,
(Models.bert, Tasks.sentence_embedding):
Preprocessors.sentence_embedding,
(Models.bert, Tasks.text_classification):
Preprocessors.sen_cls_tokenizer,
(Models.bert, Tasks.nli):
Preprocessors.sen_cls_tokenizer,
(Models.bert, Tasks.sentiment_classification):
Preprocessors.sen_cls_tokenizer,
(Models.bert, Tasks.sentence_similarity):
Preprocessors.sen_cls_tokenizer,
(Models.bert, Tasks.zero_shot_classification):
Preprocessors.sen_cls_tokenizer,
(Models.bert, Tasks.text_ranking):
Preprocessors.text_ranking,
(Models.bert, Tasks.part_of_speech):
Preprocessors.token_cls_tokenizer,
(Models.bert, Tasks.token_classification):
Preprocessors.token_cls_tokenizer,
(Models.bert, Tasks.word_segmentation):
Preprocessors.token_cls_tokenizer,

# bloom
(Models.bloom, Tasks.backbone):
Preprocessors.text_gen_tokenizer,

# gpt_neo
# gpt_neo may have different preprocessors, but now only one
(Models.gpt_neo, Tasks.backbone):
Preprocessors.sentence_piece,

# gpt3 has different preprocessors by different sizes of models, so they are not listed here.

# palm_v2
(Models.palm, Tasks.backbone):
Preprocessors.text_gen_tokenizer,

# T5
(Models.T5, Tasks.backbone):
Preprocessors.text2text_gen_preprocessor,
(Models.T5, Tasks.text2text_generation):
Preprocessors.text2text_gen_preprocessor,

# deberta_v2
(Models.deberta_v2, Tasks.backbone):
Preprocessors.sen_cls_tokenizer,
(Models.deberta_v2, Tasks.fill_mask):
Preprocessors.fill_mask,

# ponet
(Models.ponet, Tasks.fill_mask):
Preprocessors.fill_mask_ponet,

# structbert
(Models.structbert, Tasks.backbone):
Preprocessors.sen_cls_tokenizer,
(Models.structbert, Tasks.fill_mask):
Preprocessors.fill_mask,
(Models.structbert, Tasks.faq_question_answering):
Preprocessors.faq_question_answering_preprocessor,
(Models.structbert, Tasks.text_classification):
Preprocessors.sen_cls_tokenizer,
(Models.structbert, Tasks.nli):
Preprocessors.sen_cls_tokenizer,
(Models.structbert, Tasks.sentiment_classification):
Preprocessors.sen_cls_tokenizer,
(Models.structbert, Tasks.sentence_similarity):
Preprocessors.sen_cls_tokenizer,
(Models.structbert, Tasks.zero_shot_classification):
Preprocessors.sen_cls_tokenizer,
(Models.structbert, Tasks.part_of_speech):
Preprocessors.token_cls_tokenizer,
(Models.structbert, Tasks.token_classification):
Preprocessors.token_cls_tokenizer,
(Models.structbert, Tasks.word_segmentation):
Preprocessors.token_cls_tokenizer,

# veco
(Models.veco, Tasks.backbone):
Preprocessors.sen_cls_tokenizer,
(Models.veco, Tasks.fill_mask):
Preprocessors.fill_mask,
(Models.veco, Tasks.text_classification):
Preprocessors.sen_cls_tokenizer,
(Models.veco, Tasks.nli):
Preprocessors.sen_cls_tokenizer,
(Models.veco, Tasks.sentiment_classification):
Preprocessors.sen_cls_tokenizer,
(Models.veco, Tasks.sentence_similarity):
Preprocessors.sen_cls_tokenizer,

# space
}


class Preprocessor(ABC):

@@ -56,37 +163,59 @@ class Preprocessor(ABC):
if 'task' in kwargs:
task = kwargs.pop('task')
field_name = Tasks.find_field_by_task(task)
sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'

if not hasattr(cfg, 'preprocessor'):
logger.error('No preprocessor field found in cfg.')
return None
sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'
preprocessor_cfg = ConfigDict()
else:
preprocessor_cfg = cfg.preprocessor

if 'type' not in cfg.preprocessor:
if sub_key in cfg.preprocessor:
sub_cfg = getattr(cfg.preprocessor, sub_key)
if 'type' not in preprocessor_cfg:
if sub_key in preprocessor_cfg:
sub_cfg = getattr(preprocessor_cfg, sub_key)
else:
logger.error(
f'No {sub_key} key and type key found in '
f'preprocessor domain of configuration.json file.')
return None
sub_cfg = preprocessor_cfg
else:
sub_cfg = cfg.preprocessor
sub_cfg = preprocessor_cfg

if len(sub_cfg):
sub_cfg.update({'model_dir': model_dir})
sub_cfg.update(kwargs)
if 'type' in sub_cfg:
if isinstance(sub_cfg, Sequence):
# TODO: for Sequence, need adapt to `mode` and `mode_dir` args,
# and add mode for Compose or other plans
raise NotImplementedError('Not supported yet!')
sub_cfg = deepcopy(sub_cfg)
sub_cfg.update({'model_dir': model_dir})
sub_cfg.update(kwargs)

preprocessor = build_preprocessor(sub_cfg, field_name)
else:
logger.error(
f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, '
f'please check the preprocessor field in the configuration.json file.'
f'current config: {sub_cfg}. trying to build by task and model information.'
)
return None
model_cfg = getattr(cfg, 'model', ConfigDict())
model_type = model_cfg.type if hasattr(
model_cfg, 'type') else getattr(model_cfg, 'model_type', None)
if task is None or model_type is None:
logger.error(
f'Find task: {task}, model type: {model_type}. '
f'Insufficient information to build preprocessor, skip building preprocessor'
)
return None
if (model_type, task) not in PREPROCESSOR_MAP:
logger.error(
f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
f'skip building preprocessor.')
return None

sub_cfg = ConfigDict({
'type': PREPROCESSOR_MAP[(model_type, task)],
**sub_cfg
})
preprocessor = build_preprocessor(sub_cfg, field_name)
preprocessor.mode = preprocessor_mode
return preprocessor
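
In effect, when configuration.json has no usable preprocessor type, Preprocessor.from_pretrained now falls back to a (model type, task) lookup; a hedged sketch of that resolution (the chosen key is just one entry copied from the table above):

from modelscope.metainfo import Models, Preprocessors
from modelscope.utils.constant import Tasks

# Mirrors the fallback above; the real table is PREPROCESSOR_MAP in
# modelscope/preprocessors/base.py.
model_type, task = Models.palm, Tasks.backbone   # e.g. cfg.model.type and the 'task' kwarg
preprocessor_type = {
    (Models.palm, Tasks.backbone): Preprocessors.text_gen_tokenizer,
}.get((model_type, task))
print(preprocessor_type)  # registered name handed to build_preprocessor, or None if unmapped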

+20 -2  modelscope/preprocessors/ofa/ocr_recognition.py

@@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False)
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, sample['target'][:-1]])
return sample

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
inputs = self.tokenize_text(prompt)
@@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
return sample

+1 -3  modelscope/trainers/multi_modal/ofa/ofa_trainer.py

@@ -129,9 +129,7 @@ class OFATrainer(EpochBasedTrainer):

def train_step(self, model, inputs):
model.train()
model_outputs = model.forward(inputs)
loss, sample_size, logging_output = self.criterion(
model_outputs, inputs)
loss, sample_size, logging_output = self.criterion(model, inputs)
train_outputs = {'loss': loss}
# add model output info to log
if 'log_vars' not in train_outputs:


+15 -11  modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py

@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
self.padding_idx = args.tokenizer.pad_token_id
self.args = args

def forward(self, output, sample, update_num=0, reduce=True):
def forward(self, model, sample, update_num=0, reduce=True):
"""Compute the loss for the given sample.

Returns a tuple with three elements:
@@ -131,11 +131,16 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
if 'labels' in sample:
del sample['labels']
if 'samples' in sample:
del sample['samples']

if self.use_rdrop:
construct_rdrop_sample(sample)

output = model.model(**sample['net_input'])
loss, nll_loss, ntokens = self.compute_loss(
output, sample, update_num, reduce=reduce)
output.logits, sample, update_num, reduce=reduce)
sample_size = (
sample['target'].size(0) if self.sentence_avg else ntokens)
logging_output = {
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
}
return loss, sample_size, logging_output

def get_lprobs_and_target(self, net_output, sample):
def get_lprobs_and_target(self, logits, sample):
conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
'conf'] is not None else 1
constraint_masks = None
if 'constraint_masks' in sample and sample[
'constraint_masks'] is not None:
constraint_masks = sample['constraint_masks']
net_output[0].masked_fill_(~constraint_masks, -math.inf)
logits.masked_fill_(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
lprobs = F.log_softmax(
net_output[0], dim=-1, dtype=torch.float32) * conf
logits[:, :, 4:self.constraint_start] = -math.inf
logits[:, :, self.constraint_end:] = -math.inf
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
target = sample['target']
if self.ignore_prefix_size > 0:
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
return lprobs.view(-1,
lprobs.size(-1)), target.view(-1), constraint_masks

def compute_loss(self, net_output, sample, update_num, reduce=True):
def compute_loss(self, logits, sample, update_num, reduce=True):
lprobs, target, constraint_masks = self.get_lprobs_and_target(
net_output, sample)
logits, sample)
if constraint_masks is not None:
constraint_masks = constraint_masks[target != self.padding_idx]
lprobs = lprobs[target != self.padding_idx]


+3 -1  modelscope/trainers/nlp/__init__.py

@@ -7,11 +7,13 @@ if TYPE_CHECKING:
from .sequence_classification_trainer import SequenceClassificationTrainer
from .csanmt_translation_trainer import CsanmtTranslationTrainer
from .text_ranking_trainer import TextRankingTrainer
from .text_generation_trainer import TextGenerationTrainer
else:
_import_structure = {
'sequence_classification_trainer': ['SequenceClassificationTrainer'],
'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
'text_ranking_trainer': ['TextRankingTrainer']
'text_ranking_trainer': ['TextRankingTrainer'],
'text_generation_trainer': ['TextGenerationTrainer'],
}

import sys


+36 -0  modelscope/trainers/nlp/text_generation_trainer.py

@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from collections.abc import Mapping

import torch

from modelscope.metainfo import Trainers
from modelscope.trainers import NlpEpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.file_utils import func_receive_dict_inputs


@TRAINERS.register_module(module_name=Trainers.text_generation_trainer)
class TextGenerationTrainer(NlpEpochBasedTrainer):

def _decode(self, tokens):
tokenizer = self.eval_preprocessor.tokenizer
return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)

def evaluation_step(self, data):
model = self.model
model.eval()

with torch.no_grad():
if isinstance(
data,
Mapping) and not func_receive_dict_inputs(model.generate):
result = model.generate(**data)
else:
result = model.generate(data)

result['preds'] = [self._decode(seq) for seq in result['sequences']]
data['tgts'] = [self._decode(seq) for seq in data['labels']]
assert len(result['preds']) == len(data['tgts'])

return result

+24 -2  modelscope/trainers/trainer.py

@@ -855,6 +855,28 @@ class EpochBasedTrainer(BaseTrainer):

self.invoke_hook(TrainerStages.after_run)

def evaluation_step(self, data):
"""Perform a training step on a batch of inputs.

Subclass and override to inject custom behavior.

"""
model = self.model
model.eval()

if is_parallel(model):
receive_dict_inputs = func_receive_dict_inputs(
model.module.forward)
else:
receive_dict_inputs = func_receive_dict_inputs(model.forward)

with torch.no_grad():
if isinstance(data, Mapping) and not receive_dict_inputs:
result = model.forward(**data)
else:
result = model.forward(data)
return result

def evaluation_loop(self, data_loader, metric_classes):
""" Evaluation loop used by `EpochBasedTrainer.evaluate()`.

@@ -862,7 +884,7 @@ class EpochBasedTrainer(BaseTrainer):
if self._dist:
from modelscope.trainers.utils.inference import multi_gpu_test
metric_values = multi_gpu_test(
self.model,
self,
data_loader,
device=self.device,
tmpdir=None,
@@ -872,7 +894,7 @@ class EpochBasedTrainer(BaseTrainer):
else:
from modelscope.trainers.utils.inference import single_gpu_test
metric_values = single_gpu_test(
self.model,
self,
data_loader,
device=self.device,
metric_classes=metric_classes,


+8 -28  modelscope/trainers/utils/inference.py

@@ -4,29 +4,25 @@ import logging
import os
import pickle
import shutil
import time
from collections.abc import Mapping

import torch
from torch import distributed as dist
from tqdm import tqdm

from modelscope.trainers.parallel.utils import is_parallel
from modelscope.utils.data_utils import to_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
make_tmp_dir)


def single_gpu_test(model,
def single_gpu_test(trainer,
data_loader,
device,
metric_classes=None,
data_loader_iters=None):
"""Test model with a single gpu.
"""Test model in EpochBasedTrainer with a single gpu.

Args:
model (nn.Module): Model to be tested.
trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
device (str | torch.device): The target device for the data.
metric_classes (List): List of Metric class that uses to collect metrics
@@ -35,7 +31,6 @@ def single_gpu_test(model,
Returns:
list: The prediction results.
"""
model.eval()
dataset = data_loader.dataset
progress_with_iters = False
if data_loader_iters is None:
@@ -55,12 +50,7 @@ def single_gpu_test(model,
with tqdm(total=data_len, desc=desc) as pbar:
for i, data in enumerate(data_loader):
data = to_device(data, device)
with torch.no_grad():
if isinstance(data, Mapping) and not func_receive_dict_inputs(
model.forward):
result = model.forward(**data)
else:
result = model.forward(data)
result = trainer.evaluation_step(data)
if metric_classes is not None:
for metric_cls in metric_classes:
metric_cls.add(result, data)
@@ -88,14 +78,14 @@ def single_gpu_test(model,
return metric_values


def multi_gpu_test(model,
def multi_gpu_test(trainer,
data_loader,
device,
tmpdir=None,
gpu_collect=False,
metric_classes=None,
data_loader_iters_per_gpu=None):
"""Test model with multiple gpus.
"""Test model in EpochBasedTrainer with multiple gpus.

This method tests model with multiple gpus and collects the results
under two different modes: gpu and cpu modes. By setting
@@ -104,7 +94,7 @@ def multi_gpu_test(model,
different gpus to ``tmpdir`` and collects them by the rank 0 worker.

Args:
model (nn.Module): Model to be tested.
trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
device: (str | torch.device): The target device for the data.
tmpdir (str): Path of directory to save the temporary results from
@@ -115,7 +105,6 @@ def multi_gpu_test(model,
Returns:
list: The prediction results.
"""
model.eval()
results = []
data_list = []
dataset = data_loader.dataset
@@ -138,21 +127,12 @@ def multi_gpu_test(model,
data_len = data_loader_iters_per_gpu * world_size
desc = 'Total test iterations with multi gpus'

if is_parallel(model):
receive_dict_inputs = func_receive_dict_inputs(model.module.forward)
else:
receive_dict_inputs = func_receive_dict_inputs(model.forward)

count = 0
with tqdm(total=data_len, desc=desc) as pbar:
for i, data in enumerate(data_loader):
data = to_device(data, device)
data_list.append(data)
with torch.no_grad():
if isinstance(data, Mapping) and not receive_dict_inputs:
result = model.forward(**data)
else:
result = model.forward(data)
result = trainer.evaluation_step(data)
results.append(result)

if isinstance(data, dict):


+13 -2  tests/hub/test_hub_upload.py

@@ -7,7 +7,7 @@ import uuid

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import HTTPError, NotLoginException
from modelscope.hub.errors import GitError, HTTPError, NotLoginException
from modelscope.hub.repository import Repository
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
@@ -97,6 +97,17 @@ class HubUploadTest(unittest.TestCase):
revision='new_revision/version1')
assert os.path.exists(os.path.join(add4_path, 'add4.py'))
shutil.rmtree(self.repo_path, ignore_errors=True)
assert os.path.exists(os.path.join(self.finetune_path, 'add3.py'))
os.remove(os.path.join(self.finetune_path, 'add3.py'))
self.api.push_model(
model_id=self.create_model_name,
model_dir=self.finetune_path,
revision='new_revision/version1')
Repository(
model_dir=self.repo_path,
clone_from=self.create_model_name,
revision='new_revision/version1')
assert not os.path.exists(os.path.join(self.repo_path, 'add3.py'))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_upload_non_exists_repo(self):
@@ -133,7 +144,7 @@ class HubUploadTest(unittest.TestCase):
def test_upload_invalid_repo(self):
logger.info('test upload to invalid repo!')
self.api.login(TEST_ACCESS_TOKEN1)
with self.assertRaises(HTTPError):
with self.assertRaises((HTTPError, GitError)):
self.api.push_model(
model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
model_dir=self.finetune_path,


+12 -0  tests/pipelines/test_text_generation.py

@@ -38,7 +38,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large'
self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,'
self.gpt3_poetry_input = '天生我材必有用,'

def run_pipeline_with_model_instance(self, model_id, input):
model = Model.from_pretrained(model_id)
@@ -115,6 +117,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_instance(self.palm_model_id_en,
self.palm_input_en)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_poetry_large_with_model_name(self):
self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id,
self.gpt3_poetry_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_base_with_model_instance(self):
self.run_pipeline_with_model_instance(self.gpt3_base_model_id,
@@ -125,6 +132,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_instance(self.gpt3_large_model_id,
self.gpt3_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_poetry_large_with_model_instance(self):
self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id,
self.gpt3_poetry_input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_palm(self):
for model_id, input in ((self.palm_model_id_zh_base,


+10 -11  tests/trainers/test_finetune_mplug.py

@@ -24,17 +24,16 @@ class TestFinetuneMPlug(unittest.TestCase):
datadict = MsDataset.load(
'coco_captions_small_slice',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
lambda _: {
'question': 'what the picture describes?'
}).rename_column('image:FILE',
'image').rename_column('answer:Value', 'answer'))
self.test_dataset = MsDataset(datadict['test'].to_hf_dataset().map(
lambda _: {
'question': 'what the picture describes?'
}).rename_column('image:FILE',
'image').rename_column('answer:Value', 'answer'))

self.train_dataset = MsDataset(
datadict['train'].remap_columns({
'image:FILE': 'image',
'answer:Value': 'answer'
}).map(lambda _: {'question': 'what the picture describes?'}))
self.test_dataset = MsDataset(
datadict['test'].remap_columns({
'image:FILE': 'image',
'answer:Value': 'answer'
}).map(lambda _: {'question': 'what the picture describes?'}))
self.max_epochs = 2

def tearDown(self):


+12 -6  tests/trainers/test_finetune_text_generation.py

@@ -59,7 +59,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
work_dir=self.tmp_dir)

trainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.text_generation_trainer, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -98,7 +98,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
work_dir=self.tmp_dir)

trainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.text_generation_trainer, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -130,10 +130,16 @@ class TestFinetuneTextGeneration(unittest.TestCase):
def test_finetune_cnndm(self):
from modelscope.msdatasets import MsDataset
dataset_dict = MsDataset.load('DuReader_robust-QG')
train_dataset = dataset_dict['train'].to_hf_dataset() \
.rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
eval_dataset = dataset_dict['validation'].to_hf_dataset() \
.rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
train_dataset = dataset_dict['train'].remap_columns({
'text1': 'src_txt',
'text2': 'tgt_txt'
})
eval_dataset = dataset_dict['validation'].remap_columns({
'text1':
'src_txt',
'text2':
'tgt_txt'
})
num_warmup_steps = 200
os.environ['LOCAL_RANK'] = '0'



+27 -25  tests/trainers/test_ofa_trainer.py

@@ -5,10 +5,10 @@ import unittest

import json

from modelscope.metainfo import Metrics, Trainers
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


@@ -17,26 +17,27 @@ class TestOfaTrainer(unittest.TestCase):
def setUp(self) -> None:
self.finetune_cfg = \
{'framework': 'pytorch',
'task': 'image-captioning',
'task': 'ocr-recognition',
'model': {'type': 'ofa',
'beam_search': {'beam_size': 5,
'max_len_b': 16,
'max_len_b': 64,
'min_len': 1,
'no_repeat_ngram_size': 0},
'seed': 7,
'max_src_length': 256,
'language': 'en',
'max_src_length': 128,
'language': 'zh',
'gen_type': 'generation',
'patch_image_size': 480,
'is_document': False,
'max_image_size': 480,
'imagenet_default_mean_and_std': False},
'pipeline': {'type': 'image-captioning'},
'dataset': {'column_map': {'text': 'caption'}},
'train': {'work_dir': 'work/ckpts/caption',
'pipeline': {'type': 'ofa-ocr-recognition'},
'dataset': {'column_map': {'text': 'label'}},
'train': {'work_dir': 'work/ckpts/recognition',
# 'launcher': 'pytorch',
'max_epochs': 1,
'use_fp16': True,
'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
'lr_scheduler': {'name': 'polynomial_decay',
'warmup_proportion': 0.01,
'lr_end': 1e-07},
@@ -57,47 +58,48 @@ class TestOfaTrainer(unittest.TestCase):
'report_accuracy': False,
'sample_patch_num': 196,
'sentence_avg': False,
'use_rdrop': False},
'use_rdrop': True},
'hooks': [{'type': 'BestCkptSaverHook',
'metric_key': 'bleu-4',
'metric_key': 'accuracy',
'interval': 100},
{'type': 'TextLoggerHook', 'interval': 1},
{'type': 'IterTimerHook'},
{'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
'metrics': [{'type': 'bleu',
'eval_tokenized_bleu': False,
'ref_name': 'labels',
'hyp_name': 'caption'}]},
'metrics': [{'type': 'accuracy'}]},
'preprocessor': []}

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_std(self):
WORKSPACE = './workspace/ckpts/caption'
WORKSPACE = './workspace/ckpts/recognition'
os.makedirs(WORKSPACE, exist_ok=True)
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
with open(config_file, 'w') as writer:
json.dump(self.finetune_cfg, writer)

pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
args = dict(
model=pretrained_model,
work_dir=WORKSPACE,
train_dataset=MsDataset.load(
'coco_2014_caption',
'ocr_fudanvi_zh',
subset_name='scene',
namespace='modelscope',
split='train[:20]'),
split='train[:200]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
eval_dataset=MsDataset.load(
'coco_2014_caption',
'ocr_fudanvi_zh',
subset_name='scene',
namespace='modelscope',
split='validation[:10]'),
metrics=[Metrics.BLEU],
split='test[:20]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()

self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, 'output')))
self.assertIn(
ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
shutil.rmtree(WORKSPACE)




+10 -1  tests/trainers/utils/test_inference.py

@@ -12,6 +12,7 @@ from modelscope.metrics.builder import MetricKeys
from modelscope.metrics.sequence_classification_metric import \
SequenceClassificationMetric
from modelscope.models.base import Model
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test
from modelscope.utils.test_utils import (DistributedTestCase,
create_dummy_test_dataset, test_level)
@@ -36,6 +37,12 @@ class DummyModel(nn.Module, Model):
return dict(logits=x, loss=loss)


class DummyTrainer(EpochBasedTrainer):

def __init__(self, model):
self.model = model


def test_func(dist=False):
dummy_model = DummyModel()
dataset = dummy_dataset.to_torch_dataset()
@@ -62,8 +69,10 @@ def test_func(dist=False):
else:
test_func = single_gpu_test

dummy_trainer = DummyTrainer(dummy_model)

metric_results = test_func(
dummy_model,
dummy_trainer,
dummy_loader,
device=device,
metric_classes=[metric_class])

