@@ -266,6 +266,14 @@ class HubApi:
                 logger.info('Create new branch %s' % revision)
                 git_wrapper.new_branch(tmp_dir, revision)
                 git_wrapper.checkout(tmp_dir, revision)
+            files_in_repo = os.listdir(tmp_dir)
+            for f in files_in_repo:
+                if f[0] != '.':
+                    src = os.path.join(tmp_dir, f)
+                    if os.path.isfile(src):
+                        os.remove(src)
+                    else:
+                        shutil.rmtree(src, ignore_errors=True)
             for f in files_to_save:
                 if f[0] != '.':
                     src = os.path.join(model_dir, f)
@@ -313,6 +313,7 @@ class Trainers(object):
     nlp_base_trainer = 'nlp-base-trainer'
     nlp_veco_trainer = 'nlp-veco-trainer'
     nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
+    text_generation_trainer = 'text-generation-trainer'

     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
@@ -27,15 +27,21 @@ class AccuracyMetric(Metric):
         label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
         ground_truths = inputs[label_name]
         eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
         assert type(ground_truths) == type(eval_results)
-        if isinstance(ground_truths, list):
-            self.preds.extend(eval_results)
-            self.labels.extend(ground_truths)
-        elif isinstance(ground_truths, np.ndarray):
-            self.preds.extend(eval_results.tolist())
-            self.labels.extend(ground_truths.tolist())
-        else:
-            raise 'only support list or np.ndarray'
+        for truth in ground_truths:
+            self.labels.append(truth)
+        for result in eval_results:
+            if isinstance(truth, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)

     def evaluate(self):
         assert len(self.preds) == len(self.labels)
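
Note on the AccuracyMetric change above: string predictions are normalized with `strip().replace(' ', '')` so tokenizer spacing cannot break an exact match. A minimal standalone illustration of that normalization (values hypothetical):

    pred = '今 天 天 气 很 好'.strip().replace(' ', '')  # tokenizer output with spaces
    assert pred == '今天天气很好'  # now counted as a correct prediction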
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
+class NedMetric(Metric):
+    """The ned metric computation class for classification classes.
+
+    This metric class calculates the levenshtein distance between sentences for the whole input batches.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        if isinstance(ground_truths, list):
+            self.preds.extend(eval_results)
+            self.labels.extend(ground_truths)
+        elif isinstance(ground_truths, np.ndarray):
+            self.preds.extend(eval_results.tolist())
+            self.labels.extend(ground_truths.tolist())
+        else:
+            raise Exception('only support list or np.ndarray')
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        return {
+            MetricKeys.NED: (np.asarray([
+                1.0 - NedMetric._distance(pred, ref)
+                for pred, ref in zip(self.preds, self.labels)
+            ])).mean().item()
+        }
+
+    @staticmethod
+    def _distance(pred, ref):
+        if pred is None or ref is None:
+            raise TypeError('Argument (pred or ref) is NoneType.')
+        if pred == ref:
+            return 0.0
+        if len(pred) == 0:
+            return len(ref)
+        if len(ref) == 0:
+            return len(pred)
+        m_len = max(len(pred), len(ref))
+        if m_len == 0:
+            return 0.0
+
+        def levenshtein(s0, s1):
+            v0 = [0] * (len(s1) + 1)
+            v1 = [0] * (len(s1) + 1)
+            for i in range(len(v0)):
+                v0[i] = i
+            for i in range(len(s0)):
+                v1[0] = i + 1
+                for j in range(len(s1)):
+                    cost = 1
+                    if s0[i] == s1[j]:
+                        cost = 0
+                    v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
+                v0, v1 = v1, v0
+            return v0[len(s1)]
+
+        return levenshtein(pred, ref) / m_len
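
For reference, the score reported by `evaluate` above is `1 - levenshtein(pred, ref) / max(len(pred), len(ref))`, averaged over the batch. A minimal standalone sketch of that normalized edit distance (empty-string special cases omitted), checkable in isolation:

    def ned_distance(pred: str, ref: str) -> float:
        # two-row Levenshtein DP, normalized by the longer string
        if pred == ref:
            return 0.0
        v0 = list(range(len(ref) + 1))
        v1 = [0] * (len(ref) + 1)
        for i in range(len(pred)):
            v1[0] = i + 1
            for j in range(len(ref)):
                cost = 0 if pred[i] == ref[j] else 1
                v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
            v0, v1 = v1, v0
        return v0[len(ref)] / max(len(pred), len(ref))

    assert abs(ned_distance('kitten', 'sitting') - 3 / 7) < 1e-9  # edit distance 3, max length 7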
@@ -36,20 +36,31 @@ class TextGenerationMetric(Metric):
             for char in string
         ]).split())

-    def add(self, outputs: Dict[str, List[str]], inputs: Dict = None):
-        ground_truths = outputs['tgts']
+    def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
+        ground_truths = inputs['tgts']
         eval_results = outputs['preds']
         for truth in ground_truths:
             self.tgts.append(self.rebuild_str(truth))
         for result in eval_results:
             self.preds.append(self.rebuild_str(result))

+    def _check(self, pred: str, tgt: str) -> bool:
+
+        def remove_useless(string: str) -> str:
+            return string.replace(' ', '').replace('.', '')
+
+        return remove_useless(pred) and remove_useless(tgt)
+
     def evaluate(self):
+        assert self.preds, 'preds in TextGenerationMetric must not be empty!'
+        tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts)
+               if self._check(pred, tgt)]
+        preds, tgts = zip(*tmp)

         def mean(iter: Iterable) -> float:
             return sum(iter) / len(self.preds)

-        rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts)
+        rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
         rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
         rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
         pred_split = tuple(pred.split(' ') for pred in self.preds)
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
                        VecoForSequenceClassification,
                        VecoForTokenClassification, VecoModel, VecoTokenizer,
                        VecoTokenizerFast)
+    from .bloom import BloomModel
 else:
     _import_structure = {
         'backbones': ['SbertModel'],
@@ -107,6 +107,7 @@ else:
         'sentence_embedding': ['SentenceEmbedding'],
         'T5': ['T5ForConditionalGeneration'],
         'gpt_neo': ['GPTNeoModel'],
+        'bloom': ['BloomModel'],
     }

     import sys
@@ -0,0 +1,19 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .backbone import BloomModel
+else:
+    _import_structure = {
+        'backbone': ['BloomModel'],
+    }
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
@@ -4,10 +4,10 @@ from transformers import BloomModel as BloomModelTransform

 from modelscope.metainfo import Models
 from modelscope.models.builder import BACKBONES
-from modelscope.utils.constant import Fields
+from modelscope.utils.constant import Tasks


-@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom)
+@BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.bloom)
 class BloomModel(BloomModelTransform):

     def __init__(self, **kwargs):
@@ -342,6 +342,8 @@ class GPT3Model(PreTrainedModel):
         state_dict_file = os.path.join(pretrained_model_name_or_path,
                                        ModelFile.TORCH_MODEL_BIN_FILE)
         state_dict = torch.load(state_dict_file)
+        if 'state_dict' in state_dict:
+            state_dict = state_dict['state_dict']
         state_dict = {
             k.replace('model.language_model', 'language_model'): v
             for k, v in state_dict.items()
@@ -42,7 +42,7 @@ class GPT3ForTextGeneration(TorchModel):
         """
         return self.model(**input)

-    def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
+    def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         assert 'input_ids' in input, "generate function must accept 'input_ids' key"
         input_ids = input['input_ids']
         if 'attention_mask' in input:
@@ -59,8 +59,4 @@ class GPT3ForTextGeneration(TorchModel):
         gen_params['top_k'] = input.pop('top_k', 10)
         gen_params['top_p'] = input.pop('top_p', None)
         sample_output = self.model.generate(**gen_params)
-        return {
-            OutputKeys.TEXT:
-            self.tokenizer.decode(sample_output[0],
-                                  skip_special_tokens=True).replace(' ', '')
-        }
+        return {'sequences': sample_output[0]}
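
With this change `generate` returns raw token ids under a 'sequences' key, and decoding moves into `TextGenerationPipeline.postprocess` (see the pipeline hunks below). A hedged usage sketch of the end-to-end path, assuming the model weights are available:

    from modelscope.pipelines import pipeline

    pipe = pipeline('text-generation', model='damo/nlp_gpt3_text-generation_chinese-base')
    result = pipe('《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,')
    print(result['text'])  # decoded downstream from the model's 'sequences' output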
@@ -1314,8 +1314,8 @@ class Translator(object):

         return results

-    def __call__(self, input_ids: torch.Tensor,
-                 attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
+                 **kwargs) -> Dict[str, torch.Tensor]:
         batch = self.Batch(
             batch_size=input_ids.size()[0],
             src=input_ids,
@@ -29,22 +29,6 @@ class PalmForTextGeneration(TorchModel):
         self.tokenizer = self.model.tokenizer
         self.generator = Translator(self.model)

-    def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]',
-                                                                  ''),
-                               (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
-                               ('[CLS]', ''), ('[UNK]', ''), (' ', ''))
-        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
-                                  ('<pad>', ''), ('<s>', ''), ('</s>', ''),
-                                  ('<unk>', ' '), ('<q>', '. '))
-
-        replace_tokens = replace_tokens_roberta \
-            if self.model.config.encoder == 'roberta' else replace_tokens_bert
-        strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
-        for _old, _new in replace_tokens:
-            strings = [s.replace(_old, _new) for s in strings]
-        return strings
-
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -57,29 +41,10 @@ class PalmForTextGeneration(TorchModel):
             {
                 'loss': Tensor([12.34]), # loss for backward
             }
-            or
-            {
-                'preds': List["hello word"...] # the predicted strings
-                'tgts': List["hello world"...] # target strings
-            }
         """
-        if self.training:
-            return self.model(**input)
-        else:
-            outputs = self.generator(input['input_ids'],
-                                     input['attention_mask'])
-            preds = outputs['predictions']
-            pred_ids_list = [
-                pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
-            ]
-            tgt_ids_list = input['labels'].cpu().numpy().tolist()
-            return {
-                'preds': self._evaluate_postprocess(pred_ids_list),
-                'tgts': self._evaluate_postprocess(tgt_ids_list)
-            }
+        return self.model(**input)

-    def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
+    def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         outputs = self.generator(**input)
         preds = outputs['predictions']
-        pred_ids_list = [preds[0][0].cpu().numpy().tolist()]
-        return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]}
+        return {'sequences': [pred[0] for pred in preds]}
@@ -563,6 +563,18 @@ class MsDataset:
             self._hf_ds.reset_format()
         return self._hf_ds

+    def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset:
+        """
+        Rename columns and return the underlying hf dataset directly.
+        TODO: support native MsDataset column rename.
+
+        Args:
+            column_mapping: the mapping of the original and new column names
+
+        Returns:
+            underlying hf dataset
+        """
+        self._hf_ds.reset_format()
+        return self._hf_ds.rename_columns(column_mapping)
+
     @staticmethod
     def upload(object_name: str,
                local_file_path: str,
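
A usage sketch of the new `remap_columns` helper, mirroring the test updates further down in this diff:

    from modelscope.msdatasets import MsDataset

    dataset_dict = MsDataset.load('DuReader_robust-QG')
    # returns the underlying hf dataset with the columns renamed
    train_dataset = dataset_dict['train'].remap_columns({
        'text1': 'src_txt',
        'text2': 'tgt_txt'
    })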
@@ -26,10 +26,6 @@ class FaqQuestionAnsweringPipeline(Pipeline):
         if preprocessor is None:
             preprocessor = Preprocessor.from_pretrained(
                 model.model_dir, **kwargs)
-        if preprocessor is None:
-            from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
-            preprocessor = FaqQuestionAnsweringPreprocessor(
-                model.model_dir, **kwargs)
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)

     def _sanitize_parameters(self, **pipeline_parameters):
@@ -53,7 +53,7 @@ class TextGenerationPipeline(Pipeline):
         model = model if isinstance(model,
                                     Model) else Model.from_pretrained(model)
         cfg = read_config(model.model_dir)
-        self.postprocessor = cfg.pop('postprocessor', None)
+        self.postprocessor = cfg.pop('postprocessor', 'decode')
         if preprocessor is None:
             preprocessor_cfg = cfg.preprocessor
             preprocessor_cfg.update({
@@ -78,8 +78,37 @@ class TextGenerationPipeline(Pipeline):
         with torch.no_grad():
             return self.model.generate(inputs, **forward_params)

-    def sentence_piece(self, inputs) -> Dict[str, Tensor]:
-        return self.preprocessor.tokenizer.decode(inputs.tolist()[0])
+    def _is_chinese_char(self, word: str):
+        chinese_punctuations = (',', '。', ';', ':', '!', '?', '《', '》')
+        return len(word) == 1 \
+            and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)
+
+    def _remove_space_between_chinese_chars(self, decoded: str):
+        old_word_list = decoded.split(' ')
+        new_word_list = []
+        start = -1
+        for i, word in enumerate(old_word_list):
+            if self._is_chinese_char(word):
+                if start == -1:
+                    start = i
+            else:
+                if start != -1:
+                    new_word_list.append(''.join(old_word_list[start:i]))
+                    start = -1
+                new_word_list.append(word)
+        if start != -1:
+            new_word_list.append(''.join(old_word_list[start:]))
+        return ' '.join(new_word_list)
+    def decode(self, inputs) -> str:
+        tokenizer = self.preprocessor.tokenizer
+        return tokenizer.decode(inputs.tolist(), skip_special_tokens=True)
+
+    def roberta(self, inputs) -> str:
+        tokenizer = self.preprocessor.tokenizer
+        decoded = tokenizer.decode(inputs.tolist())
+        return decoded.replace('<q>', '. ').replace('<mask>',
+                                                    '. ').replace('</s>', '')
+
     def postprocess(self, inputs: Dict[str, Tensor],
                     **postprocess_params) -> Dict[str, str]:
@@ -91,7 +120,9 @@ class TextGenerationPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
-        return inputs if self.postprocessor is None else {
-            OutputKeys.TEXT:
-            getattr(self, self.postprocessor.replace('-', '_'))(inputs)
-        }
+        inputs = inputs['sequences']
+        if isinstance(inputs, list):
+            inputs = inputs[0]
+        decoded = getattr(self, self.postprocessor)(inputs)
+        text = self._remove_space_between_chinese_chars(decoded)
+        return {OutputKeys.TEXT: text}
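
A standalone sketch of the space-merging behavior added above: runs of single Chinese characters (or Chinese punctuation) are joined, while other tokens keep their separating spaces.

    def merge_chinese(decoded: str) -> str:
        def is_zh(w: str) -> bool:
            return len(w) == 1 and ('\u4e00' <= w <= '\u9fa5' or w in ',。;:!?《》')
        out, run = [], []
        for w in decoded.split(' '):
            if is_zh(w):
                run.append(w)          # extend the current run of Chinese chars
            else:
                if run:
                    out.append(''.join(run))  # flush the run without spaces
                    run = []
                out.append(w)
        if run:
            out.append(''.join(run))
        return ' '.join(out)

    assert merge_chinese('今 天 天 气 很 好 nice day') == '今天天气很好 nice day'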
@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
 from copy import deepcopy
 from typing import Any, Dict, Optional, Sequence

-from modelscope.utils.config import Config
+from modelscope.metainfo import Models, Preprocessors
+from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks
 from modelscope.utils.hub import read_config, snapshot_download
 from modelscope.utils.logger import get_logger
@@ -12,6 +13,112 @@ from .builder import build_preprocessor

 logger = get_logger(__name__)

+PREPROCESSOR_MAP = {
+    # nlp
+    # bart
+    (Models.bart, Tasks.text_error_correction):
+    Preprocessors.text_error_correction,
+
+    # bert
+    (Models.bert, Tasks.backbone):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.bert, Tasks.document_segmentation):
+    Preprocessors.document_segmentation,
+    (Models.bert, Tasks.fill_mask):
+    Preprocessors.fill_mask,
+    (Models.bert, Tasks.sentence_embedding):
+    Preprocessors.sentence_embedding,
+    (Models.bert, Tasks.text_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.bert, Tasks.nli):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.bert, Tasks.sentiment_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.bert, Tasks.sentence_similarity):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.bert, Tasks.zero_shot_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.bert, Tasks.text_ranking):
+    Preprocessors.text_ranking,
+    (Models.bert, Tasks.part_of_speech):
+    Preprocessors.token_cls_tokenizer,
+    (Models.bert, Tasks.token_classification):
+    Preprocessors.token_cls_tokenizer,
+    (Models.bert, Tasks.word_segmentation):
+    Preprocessors.token_cls_tokenizer,
+
+    # bloom
+    (Models.bloom, Tasks.backbone):
+    Preprocessors.text_gen_tokenizer,
+
+    # gpt_neo
+    # gpt_neo may have different preprocessors, but now only one
+    (Models.gpt_neo, Tasks.backbone):
+    Preprocessors.sentence_piece,
+
+    # gpt3 has different preprocessors by different sizes of models, so they are not listed here.
+
+    # palm_v2
+    (Models.palm, Tasks.backbone):
+    Preprocessors.text_gen_tokenizer,
+
+    # T5
+    (Models.T5, Tasks.backbone):
+    Preprocessors.text2text_gen_preprocessor,
+    (Models.T5, Tasks.text2text_generation):
+    Preprocessors.text2text_gen_preprocessor,
+
+    # deberta_v2
+    (Models.deberta_v2, Tasks.backbone):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.deberta_v2, Tasks.fill_mask):
+    Preprocessors.fill_mask,
+
+    # ponet
+    (Models.ponet, Tasks.fill_mask):
+    Preprocessors.fill_mask_ponet,
+
+    # structbert
+    (Models.structbert, Tasks.backbone):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.structbert, Tasks.fill_mask):
+    Preprocessors.fill_mask,
+    (Models.structbert, Tasks.faq_question_answering):
+    Preprocessors.faq_question_answering_preprocessor,
+    (Models.structbert, Tasks.text_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.structbert, Tasks.nli):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.structbert, Tasks.sentiment_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.structbert, Tasks.sentence_similarity):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.structbert, Tasks.zero_shot_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.structbert, Tasks.part_of_speech):
+    Preprocessors.token_cls_tokenizer,
+    (Models.structbert, Tasks.token_classification):
+    Preprocessors.token_cls_tokenizer,
+    (Models.structbert, Tasks.word_segmentation):
+    Preprocessors.token_cls_tokenizer,
+
+    # veco
+    (Models.veco, Tasks.backbone):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.veco, Tasks.fill_mask):
+    Preprocessors.fill_mask,
+    (Models.veco, Tasks.text_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.veco, Tasks.nli):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.veco, Tasks.sentiment_classification):
+    Preprocessors.sen_cls_tokenizer,
+    (Models.veco, Tasks.sentence_similarity):
+    Preprocessors.sen_cls_tokenizer,
+
+    # space
+}
+

 class Preprocessor(ABC):
@@ -56,37 +163,59 @@ class Preprocessor(ABC):
         if 'task' in kwargs:
             task = kwargs.pop('task')
         field_name = Tasks.find_field_by_task(task)
+        sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'
         if not hasattr(cfg, 'preprocessor'):
             logger.error('No preprocessor field found in cfg.')
-            return None
-        sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'
+            preprocessor_cfg = ConfigDict()
+        else:
+            preprocessor_cfg = cfg.preprocessor

-        if 'type' not in cfg.preprocessor:
-            if sub_key in cfg.preprocessor:
-                sub_cfg = getattr(cfg.preprocessor, sub_key)
+        if 'type' not in preprocessor_cfg:
+            if sub_key in preprocessor_cfg:
+                sub_cfg = getattr(preprocessor_cfg, sub_key)
             else:
                 logger.error(
                     f'No {sub_key} key and type key found in '
                     f'preprocessor domain of configuration.json file.')
-                return None
+                sub_cfg = preprocessor_cfg
         else:
-            sub_cfg = cfg.preprocessor
+            sub_cfg = preprocessor_cfg

+        if len(sub_cfg):
+            sub_cfg.update({'model_dir': model_dir})
+            sub_cfg.update(kwargs)
+
+        if 'type' in sub_cfg:
             if isinstance(sub_cfg, Sequence):
                 # TODO: for Sequence, need adapt to `mode` and `mode_dir` args,
                 # and add mode for Compose or other plans
                 raise NotImplementedError('Not supported yet!')
             sub_cfg = deepcopy(sub_cfg)
-            sub_cfg.update({'model_dir': model_dir})
-            sub_cfg.update(kwargs)

             preprocessor = build_preprocessor(sub_cfg, field_name)
         else:
             logger.error(
                 f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, '
-                f'please check the preprocessor field in the configuration.json file.'
+                f'current config: {sub_cfg}. trying to build by task and model information.'
             )
-            return None
+            model_cfg = getattr(cfg, 'model', ConfigDict())
+            model_type = model_cfg.type if hasattr(
+                model_cfg, 'type') else getattr(model_cfg, 'model_type', None)
+            if task is None or model_type is None:
+                logger.error(
+                    f'Find task: {task}, model type: {model_type}. '
+                    f'Insufficient information to build preprocessor, skip building preprocessor'
+                )
+                return None
+            if (model_type, task) not in PREPROCESSOR_MAP:
+                logger.error(
+                    f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
+                    f'skip building preprocessor.')
+                return None
+            sub_cfg = ConfigDict({
+                'type': PREPROCESSOR_MAP[(model_type, task)],
+                **sub_cfg
+            })
+            preprocessor = build_preprocessor(sub_cfg, field_name)
         preprocessor.mode = preprocessor_mode
         return preprocessor
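
The effect of the fallback above: when configuration.json carries no usable `preprocessor.type`, the `(model type, task)` pair indexes PREPROCESSOR_MAP to choose one. A minimal sketch of that lookup, excerpting a single entry from the table:

    from modelscope.metainfo import Models, Preprocessors
    from modelscope.utils.constant import Tasks

    lookup = {(Models.bert, Tasks.fill_mask): Preprocessors.fill_mask}  # PREPROCESSOR_MAP excerpt
    model_type, task = Models.bert, Tasks.fill_mask
    assert lookup[(model_type, task)] == Preprocessors.fill_mask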
@@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data[self.column_map['text']]
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
         inputs = self.tokenize_text(prompt)
@@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
+        if 'text' in self.column_map and self.column_map['text'] in data:
+            sample['label'] = data[self.column_map['text']]
         return sample
@@ -129,9 +129,7 @@ class OFATrainer(EpochBasedTrainer):

     def train_step(self, model, inputs):
         model.train()
-        model_outputs = model.forward(inputs)
-        loss, sample_size, logging_output = self.criterion(
-            model_outputs, inputs)
+        loss, sample_size, logging_output = self.criterion(model, inputs)
         train_outputs = {'loss': loss}
         # add model output info to log
         if 'log_vars' not in train_outputs:
@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.padding_idx = args.tokenizer.pad_token_id
         self.args = args

-    def forward(self, output, sample, update_num=0, reduce=True):
+    def forward(self, model, sample, update_num=0, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -131,11 +131,16 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
+        if 'labels' in sample:
+            del sample['labels']
+        if 'samples' in sample:
+            del sample['samples']
         if self.use_rdrop:
             construct_rdrop_sample(sample)

+        output = model.model(**sample['net_input'])
         loss, nll_loss, ntokens = self.compute_loss(
-            output, sample, update_num, reduce=reduce)
+            output.logits, sample, update_num, reduce=reduce)
         sample_size = (
             sample['target'].size(0) if self.sentence_avg else ntokens)
         logging_output = {
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         }
         return loss, sample_size, logging_output

-    def get_lprobs_and_target(self, net_output, sample):
+    def get_lprobs_and_target(self, logits, sample):
         conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
             'conf'] is not None else 1
         constraint_masks = None
         if 'constraint_masks' in sample and sample[
                 'constraint_masks'] is not None:
             constraint_masks = sample['constraint_masks']
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+            logits.masked_fill_(~constraint_masks, -math.inf)
         if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = F.log_softmax(
-            net_output[0], dim=-1, dtype=torch.float32) * conf
+            logits[:, :, 4:self.constraint_start] = -math.inf
+            logits[:, :, self.constraint_end:] = -math.inf
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
         target = sample['target']
         if self.ignore_prefix_size > 0:
             lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return lprobs.view(-1,
                            lprobs.size(-1)), target.view(-1), constraint_masks

-    def compute_loss(self, net_output, sample, update_num, reduce=True):
+    def compute_loss(self, logits, sample, update_num, reduce=True):
         lprobs, target, constraint_masks = self.get_lprobs_and_target(
-            net_output, sample)
+            logits, sample)
         if constraint_masks is not None:
             constraint_masks = constraint_masks[target != self.padding_idx]
         lprobs = lprobs[target != self.padding_idx]
@@ -7,11 +7,13 @@ if TYPE_CHECKING:
     from .sequence_classification_trainer import SequenceClassificationTrainer
     from .csanmt_translation_trainer import CsanmtTranslationTrainer
     from .text_ranking_trainer import TextRankingTrainer
+    from .text_generation_trainer import TextGenerationTrainer
 else:
     _import_structure = {
         'sequence_classification_trainer': ['SequenceClassificationTrainer'],
         'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
-        'text_ranking_trainer': ['TextRankingTrainer']
+        'text_ranking_trainer': ['TextRankingTrainer'],
+        'text_generation_trainer': ['TextGenerationTrainer'],
     }

     import sys
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from collections.abc import Mapping
+
+import torch
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import NlpEpochBasedTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.file_utils import func_receive_dict_inputs
+
+
+@TRAINERS.register_module(module_name=Trainers.text_generation_trainer)
+class TextGenerationTrainer(NlpEpochBasedTrainer):
+
+    def _decode(self, tokens):
+        tokenizer = self.eval_preprocessor.tokenizer
+        return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)
+
+    def evaluation_step(self, data):
+        model = self.model
+        model.eval()
+
+        with torch.no_grad():
+            if isinstance(
+                    data,
+                    Mapping) and not func_receive_dict_inputs(model.generate):
+                result = model.generate(**data)
+            else:
+                result = model.generate(data)
+
+        result['preds'] = [self._decode(seq) for seq in result['sequences']]
+        data['tgts'] = [self._decode(seq) for seq in data['labels']]
+        assert len(result['preds']) == len(data['tgts'])
+
+        return result
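
The new trainer is selected by its registered name; a hedged construction sketch (the model id and work_dir are placeholders, and the real tests further below pass full dataset and config kwargs):

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    kwargs = dict(
        model='damo/nlp_palm2.0_text-generation_chinese-base',  # placeholder model id
        work_dir='./tmp')
    trainer = build_trainer(
        name=Trainers.text_generation_trainer, default_args=kwargs)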
@@ -855,6 +855,28 @@ class EpochBasedTrainer(BaseTrainer):

         self.invoke_hook(TrainerStages.after_run)

+    def evaluation_step(self, data):
+        """Perform an evaluation step on a batch of inputs.
+
+        Subclass and override to inject custom behavior.
+        """
+        model = self.model
+        model.eval()
+
+        if is_parallel(model):
+            receive_dict_inputs = func_receive_dict_inputs(
+                model.module.forward)
+        else:
+            receive_dict_inputs = func_receive_dict_inputs(model.forward)
+
+        with torch.no_grad():
+            if isinstance(data, Mapping) and not receive_dict_inputs:
+                result = model.forward(**data)
+            else:
+                result = model.forward(data)
+
+        return result
+
     def evaluation_loop(self, data_loader, metric_classes):
         """ Evaluation loop used by `EpochBasedTrainer.evaluate()`.

@@ -862,7 +884,7 @@ class EpochBasedTrainer(BaseTrainer):
         if self._dist:
             from modelscope.trainers.utils.inference import multi_gpu_test
             metric_values = multi_gpu_test(
-                self.model,
+                self,
                 data_loader,
                 device=self.device,
                 tmpdir=None,
@@ -872,7 +894,7 @@ class EpochBasedTrainer(BaseTrainer):
         else:
             from modelscope.trainers.utils.inference import single_gpu_test
             metric_values = single_gpu_test(
-                self.model,
+                self,
                 data_loader,
                 device=self.device,
                 metric_classes=metric_classes,
@@ -4,29 +4,25 @@ import logging
 import os
 import pickle
 import shutil
-import time
-from collections.abc import Mapping

 import torch
 from torch import distributed as dist
 from tqdm import tqdm

-from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.data_utils import to_device
-from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
                                           make_tmp_dir)


-def single_gpu_test(model,
+def single_gpu_test(trainer,
                     data_loader,
                     device,
                     metric_classes=None,
                     data_loader_iters=None):
-    """Test model with a single gpu.
+    """Test model in EpochBasedTrainer with a single gpu.

     Args:
-        model (nn.Module): Model to be tested.
+        trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested.
         data_loader (nn.Dataloader): Pytorch data loader.
         device (str | torch.device): The target device for the data.
         metric_classes (List): List of Metric class that uses to collect metrics
@@ -35,7 +31,6 @@ def single_gpu_test(model,
     Returns:
         list: The prediction results.
     """
-    model.eval()
     dataset = data_loader.dataset
     progress_with_iters = False
     if data_loader_iters is None:
@@ -55,12 +50,7 @@ def single_gpu_test(model,
     with tqdm(total=data_len, desc=desc) as pbar:
         for i, data in enumerate(data_loader):
             data = to_device(data, device)
-            with torch.no_grad():
-                if isinstance(data, Mapping) and not func_receive_dict_inputs(
-                        model.forward):
-                    result = model.forward(**data)
-                else:
-                    result = model.forward(data)
+            result = trainer.evaluation_step(data)
             if metric_classes is not None:
                 for metric_cls in metric_classes:
                     metric_cls.add(result, data)
@@ -88,14 +78,14 @@ def single_gpu_test(model,
     return metric_values


-def multi_gpu_test(model,
+def multi_gpu_test(trainer,
                    data_loader,
                    device,
                    tmpdir=None,
                    gpu_collect=False,
                    metric_classes=None,
                    data_loader_iters_per_gpu=None):
-    """Test model with multiple gpus.
+    """Test model in EpochBasedTrainer with multiple gpus.

     This method tests model with multiple gpus and collects the results
     under two different modes: gpu and cpu modes. By setting
@@ -104,7 +94,7 @@ def multi_gpu_test(model,
     different gpus to ``tmpdir`` and collects them by the rank 0 worker.

     Args:
-        model (nn.Module): Model to be tested.
+        trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested.
         data_loader (nn.Dataloader): Pytorch data loader.
         device: (str | torch.device): The target device for the data.
         tmpdir (str): Path of directory to save the temporary results from
@@ -115,7 +105,6 @@ def multi_gpu_test(model,
     Returns:
         list: The prediction results.
     """
-    model.eval()
     results = []
     data_list = []
     dataset = data_loader.dataset
@@ -138,21 +127,12 @@ def multi_gpu_test(model,
         data_len = data_loader_iters_per_gpu * world_size
         desc = 'Total test iterations with multi gpus'

-    if is_parallel(model):
-        receive_dict_inputs = func_receive_dict_inputs(model.module.forward)
-    else:
-        receive_dict_inputs = func_receive_dict_inputs(model.forward)
-
     count = 0
     with tqdm(total=data_len, desc=desc) as pbar:
         for i, data in enumerate(data_loader):
             data = to_device(data, device)
             data_list.append(data)
-            with torch.no_grad():
-                if isinstance(data, Mapping) and not receive_dict_inputs:
-                    result = model.forward(**data)
-                else:
-                    result = model.forward(data)
+            result = trainer.evaluation_step(data)
             results.append(result)

             if isinstance(data, dict):
@@ -7,7 +7,7 @@ import uuid

 from modelscope.hub.api import HubApi
 from modelscope.hub.constants import Licenses, ModelVisibility
-from modelscope.hub.errors import HTTPError, NotLoginException
+from modelscope.hub.errors import GitError, HTTPError, NotLoginException
 from modelscope.hub.repository import Repository
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.logger import get_logger
@@ -97,6 +97,17 @@ class HubUploadTest(unittest.TestCase):
             revision='new_revision/version1')
         assert os.path.exists(os.path.join(add4_path, 'add4.py'))
         shutil.rmtree(self.repo_path, ignore_errors=True)
+        assert os.path.exists(os.path.join(self.finetune_path, 'add3.py'))
+        os.remove(os.path.join(self.finetune_path, 'add3.py'))
+        self.api.push_model(
+            model_id=self.create_model_name,
+            model_dir=self.finetune_path,
+            revision='new_revision/version1')
+        Repository(
+            model_dir=self.repo_path,
+            clone_from=self.create_model_name,
+            revision='new_revision/version1')
+        assert not os.path.exists(os.path.join(self.repo_path, 'add3.py'))

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_upload_non_exists_repo(self):
@@ -133,7 +144,7 @@ class HubUploadTest(unittest.TestCase):
     def test_upload_invalid_repo(self):
         logger.info('test upload to invalid repo!')
         self.api.login(TEST_ACCESS_TOKEN1)
-        with self.assertRaises(HTTPError):
+        with self.assertRaises((HTTPError, GitError)):
             self.api.push_model(
                 model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
                 model_dir=self.finetune_path,
@@ -38,7 +38,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
         self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
         self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
+        self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large'
         self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,'
+        self.gpt3_poetry_input = '天生我材必有用,'

     def run_pipeline_with_model_instance(self, model_id, input):
         model = Model.from_pretrained(model_id)
@@ -115,6 +117,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
         self.run_pipeline_with_model_instance(self.palm_model_id_en,
                                               self.palm_input_en)

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_gpt_poetry_large_with_model_name(self):
+        self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id,
+                                        self.gpt3_poetry_input)
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_gpt_base_with_model_instance(self):
         self.run_pipeline_with_model_instance(self.gpt3_base_model_id,
@@ -125,6 +132,11 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
         self.run_pipeline_with_model_instance(self.gpt3_large_model_id,
                                               self.gpt3_input)

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_gpt_poetry_large_with_model_instance(self):
+        self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id,
+                                              self.gpt3_poetry_input)
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_palm(self):
         for model_id, input in ((self.palm_model_id_zh_base,
@@ -24,17 +24,16 @@ class TestFinetuneMPlug(unittest.TestCase):
         datadict = MsDataset.load(
             'coco_captions_small_slice',
             download_mode=DownloadMode.FORCE_REDOWNLOAD)
-        self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
-            lambda _: {
-                'question': 'what the picture describes?'
-            }).rename_column('image:FILE',
-                             'image').rename_column('answer:Value', 'answer'))
-        self.test_dataset = MsDataset(datadict['test'].to_hf_dataset().map(
-            lambda _: {
-                'question': 'what the picture describes?'
-            }).rename_column('image:FILE',
-                             'image').rename_column('answer:Value', 'answer'))
+        self.train_dataset = MsDataset(
+            datadict['train'].remap_columns({
+                'image:FILE': 'image',
+                'answer:Value': 'answer'
+            }).map(lambda _: {'question': 'what the picture describes?'}))
+        self.test_dataset = MsDataset(
+            datadict['test'].remap_columns({
+                'image:FILE': 'image',
+                'answer:Value': 'answer'
+            }).map(lambda _: {'question': 'what the picture describes?'}))
         self.max_epochs = 2

     def tearDown(self):
@@ -59,7 +59,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             work_dir=self.tmp_dir)
         trainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.text_generation_trainer, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -98,7 +98,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             work_dir=self.tmp_dir)
         trainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.text_generation_trainer, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -130,10 +130,16 @@ class TestFinetuneTextGeneration(unittest.TestCase):
     def test_finetune_cnndm(self):
         from modelscope.msdatasets import MsDataset
         dataset_dict = MsDataset.load('DuReader_robust-QG')
-        train_dataset = dataset_dict['train'].to_hf_dataset() \
-            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
-        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
-            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        train_dataset = dataset_dict['train'].remap_columns({
+            'text1': 'src_txt',
+            'text2': 'tgt_txt'
+        })
+        eval_dataset = dataset_dict['validation'].remap_columns({
+            'text1': 'src_txt',
+            'text2': 'tgt_txt'
+        })
         num_warmup_steps = 200
         os.environ['LOCAL_RANK'] = '0'
@@ -5,10 +5,10 @@ import unittest

 import json

-from modelscope.metainfo import Metrics, Trainers
+from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level
@@ -17,26 +17,27 @@ class TestOfaTrainer(unittest.TestCase):
     def setUp(self) -> None:
         self.finetune_cfg = \
             {'framework': 'pytorch',
-             'task': 'image-captioning',
+             'task': 'ocr-recognition',
              'model': {'type': 'ofa',
                        'beam_search': {'beam_size': 5,
-                                       'max_len_b': 16,
+                                       'max_len_b': 64,
                                        'min_len': 1,
                                        'no_repeat_ngram_size': 0},
                        'seed': 7,
-                       'max_src_length': 256,
-                       'language': 'en',
+                       'max_src_length': 128,
+                       'language': 'zh',
                        'gen_type': 'generation',
                        'patch_image_size': 480,
-                       'is_document': False,
                        'max_image_size': 480,
                        'imagenet_default_mean_and_std': False},
-             'pipeline': {'type': 'image-captioning'},
-             'dataset': {'column_map': {'text': 'caption'}},
-             'train': {'work_dir': 'work/ckpts/caption',
+             'pipeline': {'type': 'ofa-ocr-recognition'},
+             'dataset': {'column_map': {'text': 'label'}},
+             'train': {'work_dir': 'work/ckpts/recognition',
                        # 'launcher': 'pytorch',
                        'max_epochs': 1,
                        'use_fp16': True,
-                       'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
+                       'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
                        'lr_scheduler': {'name': 'polynomial_decay',
                                         'warmup_proportion': 0.01,
                                         'lr_end': 1e-07},
@@ -57,47 +58,48 @@ class TestOfaTrainer(unittest.TestCase):
                                      'report_accuracy': False,
                                      'sample_patch_num': 196,
                                      'sentence_avg': False,
-                                     'use_rdrop': False},
+                                     'use_rdrop': True},
                        'hooks': [{'type': 'BestCkptSaverHook',
-                                  'metric_key': 'bleu-4',
+                                  'metric_key': 'accuracy',
                                   'interval': 100},
                                  {'type': 'TextLoggerHook', 'interval': 1},
                                  {'type': 'IterTimerHook'},
                                  {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
             'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
-                           'metrics': [{'type': 'bleu',
-                                        'eval_tokenized_bleu': False,
-                                        'ref_name': 'labels',
-                                        'hyp_name': 'caption'}]},
+                           'metrics': [{'type': 'accuracy'}]},
             'preprocessor': []}

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_std(self):
-        WORKSPACE = './workspace/ckpts/caption'
+        WORKSPACE = './workspace/ckpts/recognition'
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
             json.dump(self.finetune_cfg, writer)

-        pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
+        pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,
             train_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='train[:20]'),
+                split='train[:200]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='validation[:10]'),
-            metrics=[Metrics.BLEU],
+                split='test[:20]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             cfg_file=config_file)
         trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
-        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
-                      os.listdir(os.path.join(WORKSPACE, 'output')))
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
         shutil.rmtree(WORKSPACE)
@@ -12,6 +12,7 @@ from modelscope.metrics.builder import MetricKeys
 from modelscope.metrics.sequence_classification_metric import \
     SequenceClassificationMetric
 from modelscope.models.base import Model
+from modelscope.trainers import EpochBasedTrainer
 from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test
 from modelscope.utils.test_utils import (DistributedTestCase,
                                          create_dummy_test_dataset, test_level)
@@ -36,6 +37,12 @@ class DummyModel(nn.Module, Model):
         return dict(logits=x, loss=loss)

+class DummyTrainer(EpochBasedTrainer):
+
+    def __init__(self, model):
+        self.model = model
+
+
 def test_func(dist=False):
     dummy_model = DummyModel()
     dataset = dummy_dataset.to_torch_dataset()
@@ -62,8 +69,10 @@ def test_func(dist=False):
     else:
         test_func = single_gpu_test

+    dummy_trainer = DummyTrainer(dummy_model)
+
     metric_results = test_func(
-        dummy_model,
+        dummy_trainer,
         dummy_loader,
         device=device,
         metric_classes=[metric_class])