diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index eacde64a..5923319d 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -176,7 +176,10 @@ class HubApi: """ cookies = ModelScopeConfig.get_cookies() owner_or_group, name = model_id_to_group_owner_name(model_id) - path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' + if revision: + path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' + else: + path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}' r = requests.get(path, cookies=cookies, headers=self.headers) handle_http_response(r, logger, cookies, model_id) @@ -447,8 +450,12 @@ class HubApi: Returns: List[dict]: Model file list. """ - path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % ( - self.endpoint, model_id, revision, recursive) + if revision: + path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % ( + self.endpoint, model_id, revision, recursive) + else: + path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % ( + self.endpoint, model_id, recursive) cookies = self._check_cookie(use_cookies) if root is not None: path = path + f'&Root={root}' @@ -499,13 +506,14 @@ class HubApi: shutil.rmtree(cache_dir) os.makedirs(cache_dir, exist_ok=True) datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' - r = requests.get(datahub_url) + cookies = ModelScopeConfig.get_cookies() + r = requests.get(datahub_url, cookies=cookies) resp = r.json() datahub_raise_on_error(datahub_url, resp) dataset_id = resp['Data']['Id'] dataset_type = resp['Data']['Type'] datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' - r = requests.get(datahub_url, headers=self.headers) + r = requests.get(datahub_url, cookies=cookies, headers=self.headers) resp = r.json() datahub_raise_on_error(datahub_url, resp) file_list = resp['Data'] @@ -524,7 +532,7 @@ class HubApi: if extension in dataset_meta_format: datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' 
\ f'Revision={revision}&FilePath={file_path}' - r = requests.get(datahub_url) + r = requests.get(datahub_url, cookies=cookies) raise_for_http_status(r) local_path = os.path.join(cache_dir, file_path) if os.path.exists(local_path): @@ -569,9 +577,7 @@ class HubApi: datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ f'ststoken?Revision={revision}' - cookies = requests.utils.dict_from_cookiejar(cookies) - r = requests.get( - url=datahub_url, cookies=cookies, headers=self.headers) + r = requests.get(url=datahub_url, cookies=cookies, headers=self.headers) resp = r.json() raise_on_error(resp) return resp['Data'] @@ -582,9 +588,6 @@ class HubApi: f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' cookies = ModelScopeConfig.get_cookies() - if cookies: - cookies = requests.utils.dict_from_cookiejar(cookies) - resp = requests.get(url=url, cookies=cookies) resp = resp.json() raise_on_error(resp) @@ -593,17 +596,48 @@ class HubApi: def on_dataset_download(self, dataset_name: str, namespace: str) -> None: url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' - r = requests.post(url, headers=self.headers) + cookies = ModelScopeConfig.get_cookies() + r = requests.post(url, cookies=cookies, headers=self.headers) raise_for_http_status(r) + def delete_oss_dataset_object(self, object_name: str, dataset_name: str, + namespace: str, revision: str) -> str: + if not object_name or not dataset_name or not namespace or not revision: + raise ValueError('Args cannot be empty!') + + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}' + + cookies = self.check_local_cookies(use_cookies=True) + resp = requests.delete(url=url, cookies=cookies) + resp = resp.json() + raise_on_error(resp) + resp = resp['Message'] + return resp + + def delete_oss_dataset_dir(self, object_name: str, dataset_name: str, + namespace: str, revision: str) -> str: + if not object_name or not dataset_name or not namespace or not revision: + raise ValueError('Args cannot be empty!') + + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \ + f'&Revision={revision}' + + cookies = self.check_local_cookies(use_cookies=True) + resp = requests.delete(url=url, cookies=cookies) + resp = resp.json() + raise_on_error(resp) + resp = resp['Message'] + return resp + @staticmethod def datahub_remote_call(url): - r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) + cookies = ModelScopeConfig.get_cookies() + r = requests.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()}) resp = r.json() datahub_raise_on_error(url, resp) return resp['Data'] - def check_cookies_upload_data(self, use_cookies) -> CookieJar: + def check_local_cookies(self, use_cookies) -> CookieJar: return self._check_cookie(use_cookies=use_cookies) diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py index bfb55e6d..4c4e5dbd 100644 --- a/modelscope/hub/errors.py +++ b/modelscope/hub/errors.py @@ -63,6 +63,7 @@ def handle_http_post_error(response, url, request_body): except HTTPError as error: logger.error('Request %s with body: %s exception' % (url, request_body)) + logger.error('Response details: %s' % response.content) raise error diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 419ec919..a671ded5 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -254,6 +254,7 @@ class 
Pipelines(object): translation_en_to_de = 'translation_en_to_de' # keep it underscore translation_en_to_ro = 'translation_en_to_ro' # keep it underscore translation_en_to_fr = 'translation_en_to_fr' # keep it underscore + token_classification = 'token-classification' # audio tasks sambert_hifigan_tts = 'sambert-hifigan-tts' @@ -305,6 +306,8 @@ class Trainers(object): face_detection_scrfd = 'face-detection-scrfd' card_detection_scrfd = 'card-detection-scrfd' image_inpainting = 'image-inpainting' + referring_video_object_segmentation = 'referring-video-object-segmentation' + image_classification_team = 'image-classification-team' # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' @@ -422,6 +425,8 @@ class Metrics(object): image_inpainting_metric = 'image-inpainting-metric' # metric for ocr NED = 'ned' + # metric for referring-video-object-segmentation task + referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' class Optimizers(object): diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index c022eaf4..f106f054 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .accuracy_metric import AccuracyMetric from .bleu_metric import BleuMetric from .image_inpainting_metric import ImageInpaintingMetric + from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric else: _import_structure = { @@ -40,6 +41,8 @@ else: 'image_inpainting_metric': ['ImageInpaintingMetric'], 'accuracy_metric': ['AccuracyMetric'], 'bleu_metric': ['BleuMetric'], + 'referring_video_object_segmentation_metric': + ['ReferringVideoObjectSegmentationMetric'], } import sys diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index da3b64c7..2b61c1ae 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -43,6 +43,8 @@ task_default_metrics = { Tasks.visual_question_answering: [Metrics.text_gen_metric], Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], Tasks.image_inpainting: [Metrics.image_inpainting_metric], + Tasks.referring_video_object_segmentation: + [Metrics.referring_video_object_segmentation_metric], } diff --git a/modelscope/metrics/referring_video_object_segmentation_metric.py b/modelscope/metrics/referring_video_object_segmentation_metric.py new file mode 100644 index 00000000..5a0af30b --- /dev/null +++ b/modelscope/metrics/referring_video_object_segmentation_metric.py @@ -0,0 +1,108 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR +from typing import Dict + +import numpy as np +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from pycocotools.mask import decode +from tqdm import tqdm + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.referring_video_object_segmentation_metric) +class ReferringVideoObjectSegmentationMetric(Metric): + """The metric computation class for the referring video object segmentation task.
+ """ + + def __init__(self, + ann_file=None, + calculate_precision_and_iou_metrics=True): + self.ann_file = ann_file + self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics + self.preds = [] + + def add(self, outputs: Dict, inputs: Dict): + preds_batch = outputs['pred'] + self.preds.extend(preds_batch) + + def evaluate(self): + coco_gt = COCO(self.ann_file) + coco_pred = coco_gt.loadRes(self.preds) + coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') + coco_eval.params.useCats = 0 + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + ap_labels = [ + 'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', + 'AP 0.5:0.95 M', 'AP 0.5:0.95 L' + ] + ap_metrics = coco_eval.stats[:6] + eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)} + if self.calculate_precision_and_iou_metrics: + precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics( + coco_gt, coco_pred) + eval_metrics.update({ + f'P@{k}': m + for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k) + }) + eval_metrics.update({ + 'overall_iou': overall_iou, + 'mean_iou': mean_iou + }) + + return eval_metrics + + +def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): + outputs = outputs.int() + intersection = (outputs & labels).float().sum( + (1, 2)) # Will be zero if Truth=0 or Prediction=0 + union = (outputs | labels).float().sum( + (1, 2)) # Will be zero if both are 0 + iou = (intersection + EPS) / (union + EPS + ) # EPS is used to avoid division by zero + return iou, intersection, union + + +def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): + print('evaluating precision@k & iou metrics...') + counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} + total_intersection_area = 0 + total_union_area = 0 + ious_list = [] + for instance in tqdm(coco_gt.imgs.keys() + ): # each image_id contains exactly one instance + gt_annot = coco_gt.imgToAnns[instance][0] + gt_mask = decode(gt_annot['segmentation']) + pred_annots = coco_pred.imgToAnns[instance] + pred_annot = sorted( + pred_annots, + key=lambda a: a['score'])[-1] # choose pred with highest score + pred_mask = decode(pred_annot['segmentation']) + iou, intersection, union = compute_iou( + torch.tensor(pred_mask).unsqueeze(0), + torch.tensor(gt_mask).unsqueeze(0)) + iou, intersection, union = iou.item(), intersection.item(), union.item( + ) + for iou_threshold in counters_by_iou.keys(): + if iou > iou_threshold: + counters_by_iou[iou_threshold] += 1 + total_intersection_area += intersection + total_union_area += union + ious_list.append(iou) + num_samples = len(ious_list) + precision_at_k = np.array(list(counters_by_iou.values())) / num_samples + overall_iou = total_intersection_area / total_union_area + mean_iou = np.mean(ious_list) + return precision_at_k, overall_iou, mean_iou diff --git a/modelscope/metrics/sequence_classification_metric.py b/modelscope/metrics/sequence_classification_metric.py index 51a829ef..1fe1c329 100644 --- a/modelscope/metrics/sequence_classification_metric.py +++ b/modelscope/metrics/sequence_classification_metric.py @@ -3,6 +3,7 @@ from typing import Dict import numpy as np +from sklearn.metrics import accuracy_score, f1_score from modelscope.metainfo import Metrics from modelscope.outputs import OutputKeys @@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric): preds = np.argmax(preds, axis=1) return { MetricKeys.ACCURACY: - (preds == labels).astype(np.float32).mean().item() + accuracy_score(labels, preds), + 
MetricKeys.F1: + f1_score( + labels, + preds, + average='micro' if any([label > 1 + for label in labels]) else None), } diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py index 9bca7cf3..c2d9c6a8 100644 --- a/modelscope/metrics/text_generation_metric.py +++ b/modelscope/metrics/text_generation_metric.py @@ -2,7 +2,7 @@ from typing import Dict, Iterable, List -from nltk.translate.bleu_score import sentence_bleu +from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from rouge import Rouge from modelscope.metainfo import Metrics @@ -63,14 +63,18 @@ class TextGenerationMetric(Metric): rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) - pred_split = tuple(pred.split(' ') for pred in self.preds) - tgt_split = tuple(tgt.split(' ') for tgt in self.tgts) - bleu_1 = mean( - sentence_bleu([tgt], pred, weights=(1, 0, 0, 0)) - for pred, tgt in zip(pred_split, tgt_split)) - bleu_4 = mean( - sentence_bleu([tgt], pred) - for pred, tgt in zip(pred_split, tgt_split)) + + pred_list = [each.strip().split(' ') for each in self.preds] + tgt_list = [[each.strip().split(' ')] for each in self.tgts] + bleu_1 = corpus_bleu( + tgt_list, + pred_list, + weights=(1, 0, 0, 0), + smoothing_function=SmoothingFunction().method3) + bleu_4 = corpus_bleu( + tgt_list, + pred_list, + smoothing_function=SmoothingFunction().method3) return { MetricKeys.ROUGE_1: rouge_1, MetricKeys.ROUGE_L: rouge_l, diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 1246551e..e01d1f05 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -67,8 +67,28 @@ class Model(ABC): cfg_dict: Config = None, device: str = None, **kwargs): - """ Instantiate a model from local directory or remote model repo. Note + """Instantiate a model from local directory or remote model repo. Note that when loading from remote, the model revision can be specified. + + Args: + model_name_or_path(str): A model dir or a model id to be loaded + revision(str, `optional`): The revision used when the model_name_or_path is + a model id of the remote hub. default `master`. + cfg_dict(Config, `optional`): An optional model config. If provided, it will replace + the config read out of the `model_name_or_path` + device(str, `optional`): The device to load the model. + **kwargs: + task(str, `optional`): The `Tasks` enumeration value to replace the task value + read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not + equal to the model saved. + For example, load a `backbone` into a `text-classification` model. + Other kwargs will be directly fed into the `model` key, to replace the default configs. + Returns: + A model instance. 
+ + Examples: + >>> from modelscope.models import Model + >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification') """ prefetched = kwargs.get('model_prefetched') if prefetched is not None: diff --git a/modelscope/models/cv/referring_video_object_segmentation/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/__init__.py index 58dbf7b0..4c97bd7b 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/__init__.py +++ b/modelscope/models/cv/referring_video_object_segmentation/__init__.py @@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .model import MovieSceneSegmentation + from .model import ReferringVideoObjectSegmentation else: _import_structure = { - 'model': ['MovieSceneSegmentation'], + 'model': ['ReferringVideoObjectSegmentation'], } import sys diff --git a/modelscope/models/cv/referring_video_object_segmentation/model.py b/modelscope/models/cv/referring_video_object_segmentation/model.py index 902a3416..91f7ea91 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/model.py +++ b/modelscope/models/cv/referring_video_object_segmentation/model.py @@ -1,4 +1,6 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR + import os.path as osp from typing import Any, Dict @@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger -from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, +from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher, + ReferYoutubeVOSPostProcess, SetCriterion, + flatten_temporal_batch_dims, nested_tensor_from_videos_list) logger = get_logger() @@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel): params_dict = params_dict['model_state_dict'] self.model.load_state_dict(params_dict, strict=True) - dataset_name = self.cfg.pipeline.dataset_name - if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences': - self.postprocessor = A2DSentencesPostProcess() - elif dataset_name == 'ref_youtube_vos': - self.postprocessor = ReferYoutubeVOSPostProcess() + self.set_postprocessor(self.cfg.pipeline.dataset_name) + self.set_criterion() + + def set_device(self, device, name): + self.device = device + self._device_name = name + + def set_postprocessor(self, dataset_name): + if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name: + self.postprocessor = A2DSentencesPostProcess() # fine-tune + elif 'ref_youtube_vos' in dataset_name: + self.postprocessor = ReferYoutubeVOSPostProcess() # inference else: assert False, f'postprocessing for dataset: {dataset_name} is not supported' - def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: - return inputs + def forward(self, inputs: Dict[str, Any]): + samples = inputs['samples'] + targets = inputs['targets'] + text_queries = inputs['text_queries'] + + valid_indices = torch.tensor( + [i for i, t in enumerate(targets) if None not in t]) + targets = [targets[i] for i in valid_indices.tolist()] + if self._device_name == 'gpu': + samples = samples.to(self.device) + valid_indices = valid_indices.to(self.device) + if isinstance(text_queries, tuple): + text_queries = list(text_queries) + + outputs = self.model(samples, valid_indices, text_queries) + 
losses = -1 + if self.training: + loss_dict = self.criterion(outputs, targets) + weight_dict = self.criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] + for k in loss_dict.keys() if k in weight_dict) + + predictions = [] + if not self.training: + outputs.pop('aux_outputs', None) + outputs, targets = flatten_temporal_batch_dims(outputs, targets) + processed_outputs = self.postprocessor( + outputs, + resized_padded_sample_size=samples.tensors.shape[-2:], + resized_sample_sizes=[t['size'] for t in targets], + orig_sample_sizes=[t['orig_size'] for t in targets]) + image_ids = [t['image_id'] for t in targets] + predictions = [] + for p, image_id in zip(processed_outputs, image_ids): + for s, m in zip(p['scores'], p['rle_masks']): + predictions.append({ + 'image_id': image_id, + 'category_id': + 1, # dummy label, as categories are not predicted in ref-vos + 'segmentation': m, + 'score': s.item() + }) + + re = dict(pred=predictions, loss=losses) + return re def inference(self, **kwargs): window = kwargs['window'] @@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel): def postprocess(self, inputs: Dict[str, Any], **kwargs): return inputs + + def set_criterion(self): + matcher = HungarianMatcher( + cost_is_referred=self.cfg.matcher.set_cost_is_referred, + cost_dice=self.cfg.matcher.set_cost_dice) + weight_dict = { + 'loss_is_referred': self.cfg.loss.is_referred_loss_coef, + 'loss_dice': self.cfg.loss.dice_loss_coef, + 'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef + } + + if self.cfg.loss.aux_loss: + aux_weight_dict = {} + for i in range(self.cfg.model.num_decoder_layers - 1): + aux_weight_dict.update( + {k + f'_{i}': v + for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + self.criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, + eos_coef=self.cfg.loss.eos_coef) diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py index 796bd6f4..fbb75b00 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .misc import nested_tensor_from_videos_list +from .criterion import SetCriterion, flatten_temporal_batch_dims +from .matcher import HungarianMatcher +from .misc import interpolate, nested_tensor_from_videos_list from .mttr import MTTR from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py b/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py new file mode 100644 index 00000000..a4d2f0ff --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py @@ -0,0 +1,198 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +import torch +from torch import nn + +from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized, + nested_tensor_from_tensor_list) +from .segmentation import dice_loss, sigmoid_focal_loss + + +class SetCriterion(nn.Module): + """ This class computes the loss for MTTR. 
+ The process happens in two steps: + 1) we compute the Hungarian assignment between the ground-truth and predicted sequences. + 2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction) + """ + + def __init__(self, matcher, weight_dict, eos_coef): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the un-referred category + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + # make sure that only loss functions with non-zero weights are computed: + losses_to_compute = [] + if weight_dict['loss_dice'] > 0 or weight_dict[ + 'loss_sigmoid_focal'] > 0: + losses_to_compute.append('masks') + if weight_dict['loss_is_referred'] > 0: + losses_to_compute.append('is_referred') + self.losses = losses_to_compute + + def forward(self, outputs, targets): + aux_outputs_list = outputs.pop('aux_outputs', None) + # compute the losses for the output of the last decoder layer: + losses = self.compute_criterion( + outputs, targets, losses_to_compute=self.losses) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer. + if aux_outputs_list is not None: + aux_losses_to_compute = self.losses.copy() + for i, aux_outputs in enumerate(aux_outputs_list): + losses_dict = self.compute_criterion(aux_outputs, targets, + aux_losses_to_compute) + losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()} + losses.update(losses_dict) + + return losses + + def compute_criterion(self, outputs, targets, losses_to_compute): + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs, targets) + + # T & B dims are flattened so loss functions can be computed per frame (but with same indices per video). + # also, indices are repeated so the same indices can be used for frames of the same video.
+ T = len(targets) + outputs, targets = flatten_temporal_batch_dims(outputs, targets) + # repeat the indices list T times so the same indices can be used for each video frame + indices = T * indices + + # Compute the average number of target masks across all nodes, for normalization purposes + num_masks = sum(len(t['masks']) for t in targets) + num_masks = torch.as_tensor([num_masks], + dtype=torch.float, + device=indices[0][0].device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_masks) + num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in losses_to_compute: + losses.update( + self.get_loss( + loss, outputs, targets, indices, num_masks=num_masks)) + return losses + + def loss_is_referred(self, outputs, targets, indices, **kwargs): + device = outputs['pred_is_referred'].device + bs = outputs['pred_is_referred'].shape[0] + pred_is_referred = outputs['pred_is_referred'].log_softmax( + dim=-1) # note that log-softmax is used here + target_is_referred = torch.zeros_like(pred_is_referred) + # extract indices of object queries that where matched with text-referred target objects + query_referred_indices = self._get_query_referred_indices( + indices, targets) + # by default penalize compared to the no-object class (last token) + target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) + if 'is_ref_inst_visible' in targets[ + 0]: # visibility labels are available per-frame for the referred object: + is_ref_inst_visible_per_frame = torch.stack( + [t['is_ref_inst_visible'] for t in targets]) + ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero( + ).squeeze() + # keep only the matched query indices of the frames in which the referred object is visible: + visible_query_referred_indices = query_referred_indices[ + ref_inst_visible_frame_indices] + target_is_referred[ref_inst_visible_frame_indices, + visible_query_referred_indices] = torch.tensor( + [1.0, 0.0], device=device) + else: # assume that the referred object is visible in every frame: + target_is_referred[torch.arange(bs), + query_referred_indices] = torch.tensor( + [1.0, 0.0], device=device) + loss = -(pred_is_referred * target_is_referred).sum(-1) + # apply no-object class weights: + eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device) + eos_coef[torch.arange(bs), query_referred_indices] = 1.0 + loss = loss * eos_coef + bs = len(indices) + loss = loss.sum() / bs # sum and normalize the loss by the batch size + losses = {'loss_is_referred': loss} + return losses + + def loss_masks(self, outputs, targets, indices, num_masks, **kwargs): + assert 'pred_masks' in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs['pred_masks'] + src_masks = src_masks[src_idx] + masks = [t['masks'] for t in targets] + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = interpolate( + src_masks[:, None], + size=target_masks.shape[-2:], + mode='bilinear', + align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + 'loss_sigmoid_focal': + sigmoid_focal_loss(src_masks, target_masks, num_masks), + 'loss_dice': + dice_loss(src_masks, target_masks, num_masks), + 
} + return losses + + @staticmethod + def _get_src_permutation_idx(indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + @staticmethod + def _get_tgt_permutation_idx(indices): + # permute targets following indices + batch_idx = torch.cat( + [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + @staticmethod + def _get_query_referred_indices(indices, targets): + """ + extract indices of object queries that were matched with text-referred target objects + """ + query_referred_indices = [] + for (query_idxs, target_idxs), target in zip(indices, targets): + ref_query_idx = query_idxs[torch.where( + target_idxs == target['referred_instance_idx'])[0]] + query_referred_indices.append(ref_query_idx) + query_referred_indices = torch.cat(query_referred_indices) + return query_referred_indices + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + 'masks': self.loss_masks, + 'is_referred': self.loss_is_referred, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + +def flatten_temporal_batch_dims(outputs, targets): + for k in outputs.keys(): + if isinstance(outputs[k], torch.Tensor): + outputs[k] = outputs[k].flatten(0, 1) + else: # list + outputs[k] = [i for step_t in outputs[k] for i in step_t] + targets = [ + frame_t_target for step_t in targets for frame_t_target in step_t + ] + return outputs, targets diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py b/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py new file mode 100644 index 00000000..4f9b88e5 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py @@ -0,0 +1,163 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +# Module to compute the matching cost and solve the corresponding LSAP. + +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from .misc import interpolate, nested_tensor_from_tensor_list + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects).
+ """ + + def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1): + """Creates the matcher + + Params: + cost_is_referred: This is the relative weight of the reference cost in the total matching cost + cost_dice: This is the relative weight of the dice cost in the total matching cost + """ + super().__init__() + self.cost_is_referred = cost_is_referred + self.cost_dice = cost_dice + assert cost_is_referred != 0 or cost_dice != 0, 'all costs cant be 0' + + @torch.inference_mode() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: A dict that contains at least these entries: + "pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits + "pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits + + targets: A list of lists of targets (outer - time steps, inner - batch samples). each target is a dict + which contain mask and reference ground truth information for a single frame. + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_masks) + """ + t, bs, num_queries = outputs['pred_masks'].shape[:3] + + # We flatten to compute the cost matrices in a batch + out_masks = outputs['pred_masks'].flatten( + 1, 2) # [t, batch_size * num_queries, mask_h, mask_w] + + # preprocess and concat the target masks + tgt_masks = [[ + m for v in t_step_batch for m in v['masks'].unsqueeze(1) + ] for t_step_batch in targets] + # pad the target masks to a uniform shape + tgt_masks, valid = list( + zip(*[ + nested_tensor_from_tensor_list(t).decompose() + for t in tgt_masks + ])) + tgt_masks = torch.stack(tgt_masks).squeeze(2) + + # upsample predicted masks to target mask size + out_masks = interpolate( + out_masks, + size=tgt_masks.shape[-2:], + mode='bilinear', + align_corners=False) + + # Compute the soft-tokens cost: + if self.cost_is_referred > 0: + cost_is_referred = compute_is_referred_cost(outputs, targets) + else: + cost_is_referred = 0 + + # Compute the DICE coefficient between the masks: + if self.cost_dice > 0: + cost_dice = -dice_coef(out_masks, tgt_masks) + else: + cost_dice = 0 + + # Final cost matrix + C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice + C = C.view(bs, num_queries, -1).cpu() + + num_traj_per_batch = [ + len(v['masks']) for v in targets[0] + ] # number of instance trajectories in each batch + indices = [ + linear_sum_assignment(c[i]) + for i, c in enumerate(C.split(num_traj_per_batch, -1)) + ] + device = out_masks.device + return [(torch.as_tensor(i, dtype=torch.int64, device=device), + torch.as_tensor(j, dtype=torch.int64, device=device)) + for i, j in indices] + + +def dice_coef(inputs, targets, smooth=1.0): + """ + Compute the DICE coefficient, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). 
+ """ + inputs = inputs.sigmoid().flatten(2).unsqueeze(2) + targets = targets.flatten(2).unsqueeze(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + coef = (numerator + smooth) / (denominator + smooth) + coef = coef.mean( + 0) # average on the temporal dim to get instance trajectory scores + return coef + + +def compute_is_referred_cost(outputs, targets): + pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax( + dim=-1) # [t, b*nq, 2] + device = pred_is_referred.device + t = pred_is_referred.shape[0] + # number of instance trajectories in each batch + num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]], + device=device) + total_trajectories = num_traj_per_batch.sum() + # note that ref_indices are shared across time steps: + ref_indices = torch.tensor( + [v['referred_instance_idx'] for v in targets[0]], device=device) + # convert ref_indices to fit flattened batch targets: + ref_indices += torch.cat( + (torch.zeros(1, dtype=torch.long, + device=device), num_traj_per_batch.cumsum(0)[:-1])) + # number of instance trajectories in each batch + target_is_referred = torch.zeros((t, total_trajectories, 2), device=device) + # 'no object' class by default (for un-referred objects) + target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) + if 'is_ref_inst_visible' in targets[0][ + 0]: # visibility labels are available per-frame for the referred object: + is_ref_inst_visible = torch.stack([ + torch.stack([t['is_ref_inst_visible'] for t in t_step]) + for t_step in targets + ]).permute(1, 0) + for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible): + is_visible = is_visible.nonzero().squeeze() + target_is_referred[is_visible, + ref_idx, :] = torch.tensor([1.0, 0.0], + device=device) + else: # assume that the referred object is visible in every frame: + target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0], + device=device) + cost_is_referred = -(pred_is_referred.unsqueeze(2) + * target_is_referred.unsqueeze(1)).sum(dim=-1).mean( + dim=0) + return cost_is_referred diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py index 8c24e397..39962715 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py @@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module): with torch.inference_mode(mode=self.freeze_text_encoder): encoded_text = self.text_encoder(**tokenized_queries) # Transpose memory because pytorch's attention expects sequence first - txt_memory = rearrange(encoded_text.last_hidden_state, - 'b s c -> s b c') + tmp_last_hidden_state = encoded_text.last_hidden_state.clone() + txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c') txt_memory = self.txt_proj( txt_memory) # change text embeddings dim to model dim # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py index 9a08ef48..faaf6e10 100644 --- a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py @@ -123,7 
+123,8 @@ class WindowAttention3D(nn.Module): # define a parameter table of relative position bias wd, wh, ww = window_size self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) + torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_d = torch.arange(self.window_size[0]) diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py index 87a28a2f..202bdd55 100755 --- a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py +++ b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py @@ -269,8 +269,11 @@ class TinyNAS(nn.Module): the_block_class = block_info['class'] if the_block_class == 'ConvKXBNRELU': if use_focus: - the_block = Focus(block_info['in'], block_info['out'], - block_info['k']) + the_block = Focus( + block_info['in'], + block_info['out'], + block_info['k'], + act=act) else: the_block = ConvKXBNRELU( block_info['in'], diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py index 42a71381..7aff2167 100644 --- a/modelscope/models/cv/tinynas_detection/detector.py +++ b/modelscope/models/cv/tinynas_detection/detector.py @@ -6,6 +6,7 @@ import pickle import cv2 import torch +import torch.nn as nn import torchvision from modelscope.metainfo import Models @@ -47,6 +48,7 @@ class SingleStageDetector(TorchModel): self.backbone = build_backbone(self.cfg.model.backbone) self.neck = build_neck(self.cfg.model.neck) self.head = build_head(self.cfg.model.head) + self.apply(self.init_bn) self.load_pretrain_model(model_path) @@ -59,6 +61,12 @@ class SingleStageDetector(TorchModel): new_state_dict[k] = v self.load_state_dict(new_state_dict, strict=True) + def init_bn(self, M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + def inference(self, x): if self.training: diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py index 8110a0f7..655d36d2 100644 --- a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py +++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +from os import path as osp from typing import Any, Dict import json @@ -23,7 +24,8 @@ from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg from modelscope.models.multi_modal.ofa.generate.search import Sampling from modelscope.models.multi_modal.ofa.generate.utils import move_to_device -from modelscope.utils.constant import Tasks +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks try: from torchvision.transforms import InterpolationMode @@ -133,6 +135,8 @@ class OfaForTextToImageSynthesis(Model): super().__init__(model_dir=model_dir, *args, **kwargs) # Initialize ofa model = OFAModel.from_pretrained(model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) self.model = model.module if hasattr(model, 'module') else model self.tokenizer = OFATokenizer.from_pretrained(model_dir) self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) @@ -171,6 +175,8 @@ class OfaForTextToImageSynthesis(Model): 'gen_code': True, 'constraint_range': '50265,58457' } + if hasattr(self.cfg.model, 'beam_search'): + sg_args.update(self.cfg.model.beam_search) self.generator = sg.SequenceGenerator(**sg_args) def clip_tokenize(self, texts, context_length=77, truncate=False): diff --git a/modelscope/models/nlp/heads/text_generation_head.py b/modelscope/models/nlp/heads/text_generation_head.py index 606d5a1f..ecb02e22 100644 --- a/modelscope/models/nlp/heads/text_generation_head.py +++ b/modelscope/models/nlp/heads/text_generation_head.py @@ -8,7 +8,6 @@ from torch import nn from modelscope.metainfo import Heads from modelscope.models.base import TorchHead from modelscope.models.builder import HEADS -from modelscope.outputs import OutputKeys from modelscope.utils.constant import Tasks @@ -27,9 +26,8 @@ class TextGenerationHead(TorchHead): def forward(self, inputs=None): logits = self.linear(inputs) - return {OutputKeys.LOGITS: logits} + return logits - def compute_loss(self, outputs: Dict[str, torch.Tensor], + def compute_loss(self, logits: torch.Tensor, labels) -> Dict[str, torch.Tensor]: - logits = outputs[OutputKeys.LOGITS] - return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} + return F.cross_entropy(logits, labels) diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py index f17b0f6b..cd8e20cf 100644 --- a/modelscope/models/nlp/task_models/text_generation.py +++ b/modelscope/models/nlp/task_models/text_generation.py @@ -1,7 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict -import addict import numpy as np from transformers.modeling_utils import PreTrainedModel @@ -9,7 +8,8 @@ from modelscope.metainfo import TaskModels from modelscope.models.builder import MODELS from modelscope.models.nlp.task_models.task_model import \ SingleBackboneTaskModelBase -from modelscope.outputs import OutputKeys +from modelscope.outputs import (OutputKeys, TextGenerationModelOutput, + TokenGeneratorOutput) from modelscope.utils.constant import Tasks __all__ = ['TaskModelForTextGeneration'] @@ -43,12 +43,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): backbone_outputs = super().forward(input) hidden_states = backbone_outputs[0] - outputs = self.head.forward(hidden_states) + logits = self.head.forward(hidden_states) + loss = None if labels is not None: input[OutputKeys.LABELS] = labels - loss = self.compute_loss(outputs, labels) - outputs.update(loss) - return addict.Dict(outputs) + loss = self.compute_loss(logits, labels) + return TextGenerationModelOutput(logits=logits, loss=loss) def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): # only last token for inputs_ids if past is defined in kwargs @@ -76,4 +76,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): def generate(self, inputs, *args, **kwargs): input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs - return super().generate(input_ids, *args, **kwargs) + generate_output = super().generate(input_ids, *args, **kwargs) + if isinstance(generate_output, Dict): + return TokenGeneratorOutput( + sequences=generate_output.sequences, + scores=generate_output.scores, + attentions=generate_output.attentions, + hidden_states=generate_output.hidden_states) + else: + return TokenGeneratorOutput(sequences=generate_output) diff --git a/modelscope/models/nlp/task_models/token_classification.py b/modelscope/models/nlp/task_models/token_classification.py index 2739bf11..8b523baf 100644 --- a/modelscope/models/nlp/task_models/token_classification.py +++ b/modelscope/models/nlp/task_models/token_classification.py @@ -66,7 +66,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase): attentions=outputs.attentions, offset_mapping=input['offset_mapping'], ) - return outputs def extract_logits(self, outputs): return outputs[OutputKeys.LOGITS].cpu().detach() diff --git a/modelscope/models/science/unifold/modules/structure_module.py b/modelscope/models/science/unifold/modules/structure_module.py index 5d4da30b..4872d5c6 100644 --- a/modelscope/models/science/unifold/modules/structure_module.py +++ b/modelscope/models/science/unifold/modules/structure_module.py @@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module): pt_att *= pt_att pt_att = pt_att.sum(dim=-1) - head_weights = self.softplus(self.head_weights).view( - *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = self.softplus(self.head_weights).view( # noqa + *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) # noqa head_weights = head_weights * math.sqrt( 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) pt_att *= head_weights * (-0.5) diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index e90f397b..0c537df7 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -20,13 +20,15 @@ from modelscope.msdatasets.task_datasets.builder import build_task_dataset from modelscope.msdatasets.utils.dataset_builder import ExternalDataset from modelscope.msdatasets.utils.dataset_utils 
import ( get_dataset_files, get_target_dataset_structure, load_dataset_builder) +from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager from modelscope.msdatasets.utils.download_utils import DatasetDownloadManager from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager from modelscope.utils.config import ConfigDict from modelscope.utils.config_ds import MS_DATASETS_CACHE from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, DEFAULT_DATASET_REVISION, - DatasetFormations, DownloadMode, Hubs) + DatasetFormations, DownloadMode, Hubs, + UploadMode) from modelscope.utils.logger import get_logger logger = get_logger() @@ -576,15 +578,17 @@ class MsDataset: return self._hf_ds.rename_columns(column_mapping) @staticmethod - def upload(object_name: str, - local_file_path: str, - dataset_name: str, - namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, - version: Optional[str] = DEFAULT_DATASET_REVISION, - num_processes: Optional[int] = None, - chunksize: Optional[int] = 1, - filter_hidden_files: Optional[bool] = True) -> None: - """Upload dataset file or directory to the ModelScope Hub. Please login to the ModelScope Hub first. + def upload( + object_name: str, + local_file_path: str, + dataset_name: str, + namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, + version: Optional[str] = DEFAULT_DATASET_REVISION, + num_processes: Optional[int] = None, + chunksize: Optional[int] = 1, + filter_hidden_files: Optional[bool] = True, + upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None: + """Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first. Args: object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name @@ -592,7 +596,7 @@ class MsDataset: dataset_name (str): Name of the dataset namespace(str, optional): Namespace of the dataset version: Optional[str]: Version of the dataset - num_processes: Optional[int]: The number of processes used for multi-process uploading. + num_processes: Optional[int]: The number of processes used for multiprocess uploading. This is only applicable when local_file_path is a directory, and we are uploading mutliple-files insided the directory. When None provided, the number returned by os.cpu_count() is used as default. chunksize: Optional[int]: The chunksize of objects to upload. @@ -600,24 +604,34 @@ class MsDataset: using the default value of 1. Available if local_file_path is a directory. filter_hidden_files: Optional[bool]: Whether to filter hidden files. Available if local_file_path is a directory. + upload_mode: Optional[UploadMode]: How to upload objects from local. Default: UploadMode.OVERWRITE, upload + all objects from local, existing remote objects may be overwritten. 
Returns: None """ + if not object_name: + raise ValueError('object_name cannot be empty!') + _upload_manager = DatasetUploadManager( dataset_name=dataset_name, namespace=namespace, version=version) + upload_mode = UploadMode(upload_mode or UploadMode.OVERWRITE) + if os.path.isfile(local_file_path): _upload_manager.upload( - object_name=object_name, local_file_path=local_file_path) + object_name=object_name, + local_file_path=local_file_path, + upload_mode=upload_mode) elif os.path.isdir(local_file_path): _upload_manager.upload_dir( object_dir_name=object_name, local_dir_path=local_file_path, num_processes=num_processes, chunksize=chunksize, - filter_hidden_files=filter_hidden_files) + filter_hidden_files=filter_hidden_files, + upload_mode=upload_mode) else: raise ValueError( f'{local_file_path} is not a valid file path or directory') @@ -672,7 +686,7 @@ class MsDataset: revision of the model you want to clone from. Can be any of a branch, tag or commit hash auth_token(`Optional[str]`): token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter - as the token is already saved when you login the first time, if None, we will use saved token. + as the token is already saved when you log in the first time, if None, we will use saved token. git_path:(`Optional[str]`): The git command line path, if None, we use 'git' force (Optional[bool]): whether to use forced-push. @@ -687,8 +701,29 @@ class MsDataset: revision=revision, auth_token=auth_token, git_path=git_path) - _repo.push( - commit_message=commit_message, - local_branch=revision, - remote_branch=revision, - force=force) + _repo.push(commit_message=commit_message, branch=revision, force=force) + + @staticmethod + def delete(object_name: str, + dataset_name: str, + namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, + version: Optional[str] = DEFAULT_DATASET_REVISION) -> str: + """ Delete object of dataset. Please log in first and make sure you have permission to manage the dataset. + + Args: + object_name (str): The object name of dataset to be deleted. Could be a name of file or directory. If it's + directory, then ends with `/`. + For example: your-data-name.zip, train/001/img_001.png, train/, ... + dataset_name (str): Path or name of the dataset. + namespace(str, optional): Namespace of the dataset. + version (str, optional): Version of the dataset. + + Returns: + res_msg (str): Response message. 
+ + """ + _delete_manager = DatasetDeleteManager( + dataset_name=dataset_name, namespace=namespace, version=version) + resp_msg = _delete_manager.delete(object_name=object_name) + logger.info(f'Object {object_name} successfully removed!') + return resp_msg diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index 92764155..043010bf 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from .video_summarization_dataset import VideoSummarizationDataset from .image_inpainting import ImageInpaintingDataset from .text_ranking_dataset import TextRankingDataset + from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset else: _import_structure = { @@ -29,6 +30,8 @@ else: 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], 'image_portrait_enhancement_dataset': ['ImagePortraitEnhancementDataset'], + 'referring_video_object_segmentation': + ['ReferringVideoObjectSegmentationDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py new file mode 100644 index 00000000..7c1b724e --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .referring_video_object_segmentation_dataset import \ + ReferringVideoObjectSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py new file mode 100644 index 00000000..c90351e9 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py @@ -0,0 +1,361 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR + +from glob import glob +from os import path as osp + +import h5py +import json +import numpy as np +import pandas +import torch +import torch.distributed as dist +import torchvision.transforms.functional as F +from pycocotools.mask import area, encode +from torchvision.io import read_video +from tqdm import tqdm + +from modelscope.metainfo import Models +from modelscope.models.cv.referring_video_object_segmentation.utils import \ + nested_tensor_from_videos_list +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from . 
import transformers as T + +LOGGER = get_logger() + + +def get_image_id(video_id, frame_idx, ref_instance_a2d_id): + image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}' + return image_id + + +@TASK_DATASETS.register_module( + Tasks.referring_video_object_segmentation, + module_name=Models.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + LOGGER.info(kwargs) + data_cfg = kwargs.get('cfg').data_kwargs + trans_cfg = kwargs.get('cfg').transformers_kwargs + distributed = data_cfg.get('distributed', False) + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + + self.window_size = data_cfg.get('window_size', 8) + self.mask_annotations_dir = osp.join( + self.data_root, 'text_annotations/annotation_with_instances') + self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320') + self.subset_type = next(iter(split_config.keys())) + self.text_annotations = self.get_text_annotations( + self.data_root, self.subset_type, distributed) + self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg) + self.collator = Collator() + self.ann_file = osp.join( + self.data_root, + data_cfg.get('ann_file', + 'a2d_sentences_test_annotations_in_coco_format.json')) + + # create ground-truth test annotations for the evaluation process if necessary: + if self.subset_type == 'test' and not osp.exists(self.ann_file): + if (distributed and dist.get_rank() == 0) or not distributed: + create_a2d_sentences_ground_truth_test_annotations( + self.data_root, self.subset_type, + self.mask_annotations_dir, self.ann_file) + if distributed: + dist.barrier() + + def __len__(self): + return len(self.text_annotations) + + def __getitem__(self, idx): + text_query, video_id, frame_idx, instance_id = self.text_annotations[ + idx] + + text_query = ' '.join( + text_query.lower().split()) # clean up the text query + + # read the source window frames: + video_frames, _, _ = read_video( + osp.join(self.videos_dir, f'{video_id}.mp4'), + pts_unit='sec') # (T, H, W, C) + # get a window of window_size frames with frame frame_idx in the middle. + # note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx + start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + ( + self.window_size + 1) // 2 + + # extract the window source frames: + source_frames = [] + for i in range(start_idx, end_idx): + i = min(max(i, 0), + len(video_frames) + - 1) # pad out of range indices with edge frames + source_frames.append( + F.to_pil_image(video_frames[i].permute(2, 0, 1))) + + # read the instance mask: + frame_annot_path = osp.join(self.mask_annotations_dir, video_id, + f'{frame_idx:05d}.h5') + f = h5py.File(frame_annot_path, 'r') + instances = list(f['instance']) + instance_idx = instances.index( + instance_id) # existence was already validated during init + + instance_masks = np.array(f['reMask']) + if len(instances) == 1: + instance_masks = instance_masks[np.newaxis, ...] + instance_masks = torch.tensor(instance_masks).transpose(1, 2) + mask_rles = [encode(mask) for mask in instance_masks.numpy()] + mask_areas = area(mask_rles).astype(np.float) + f.close() + + # create the target dict for the center frame: + target = { + 'masks': instance_masks, + 'orig_size': instance_masks. 
+ shape[-2:], # original frame shape without any augmentations + # size with augmentations, will be changed inside transforms if necessary + 'size': instance_masks.shape[-2:], + 'referred_instance_idx': torch.tensor( + instance_idx), # idx in 'masks' of the text referred instance + 'area': torch.tensor(mask_areas), + 'iscrowd': + torch.zeros(len(instance_masks) + ), # for compatibility with DETR COCO transforms + 'image_id': get_image_id(video_id, frame_idx, instance_id) + } + + # create dummy targets for adjacent frames: + targets = self.window_size * [None] + center_frame_idx = self.window_size // 2 + targets[center_frame_idx] = target + source_frames, targets, text_query = self.transforms( + source_frames, targets, text_query) + return source_frames, targets, text_query + + @staticmethod + def get_text_annotations(root_path, subset, distributed): + saved_annotations_file_path = osp.join( + root_path, f'sentences_single_frame_{subset}_annotations.json') + if osp.exists(saved_annotations_file_path): + with open(saved_annotations_file_path, 'r') as f: + text_annotations_by_frame = [tuple(a) for a in json.load(f)] + return text_annotations_by_frame + elif (distributed and dist.get_rank() == 0) or not distributed: + print(f'building a2d sentences {subset} text annotations...') + # without 'header == None' pandas will ignore the first sample... + a2d_data_info = pandas.read_csv( + osp.join(root_path, 'Release/videoset.csv'), header=None) + # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' + a2d_data_info.columns = [ + 'vid', '', '', '', '', '', '', '', 'subset' + ] + with open( + osp.join(root_path, 'text_annotations/missed_videos.txt'), + 'r') as f: + unused_videos = f.read().splitlines() + subsets = {'train': 0, 'test': 1} + # filter unused videos and videos which do not belong to our train/test subset: + used_videos = a2d_data_info[ + ~a2d_data_info.vid.isin(unused_videos) + & (a2d_data_info.subset == subsets[subset])] + used_videos_ids = list(used_videos['vid']) + text_annotations = pandas.read_csv( + osp.join(root_path, 'text_annotations/annotation.txt')) + # filter the text annotations based on the used videos: + used_text_annotations = text_annotations[ + text_annotations.video_id.isin(used_videos_ids)] + # remove a single dataset annotation mistake in video: T6bNPuKV-wY + used_text_annotations = used_text_annotations[ + used_text_annotations['instance_id'] != '1 (copy)'] + # convert data-frame to list of tuples: + used_text_annotations = list( + used_text_annotations.to_records(index=False)) + text_annotations_by_frame = [] + mask_annotations_dir = osp.join( + root_path, 'text_annotations/annotation_with_instances') + for video_id, instance_id, text_query in tqdm( + used_text_annotations): + frame_annot_paths = sorted( + glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) + instance_id = int(instance_id) + for p in frame_annot_paths: + f = h5py.File(p) + instances = list(f['instance']) + if instance_id in instances: + # in case this instance does not appear in this frame it has no ground-truth mask, and thus this + # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. 
check out: + # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 + frame_idx = int(p.split('/')[-1].split('.')[0]) + text_query = text_query.lower( + ) # lower the text query prior to augmentation & tokenization + text_annotations_by_frame.append( + (text_query, video_id, frame_idx, instance_id)) + with open(saved_annotations_file_path, 'w') as f: + json.dump(text_annotations_by_frame, f) + if distributed: + dist.barrier() + with open(saved_annotations_file_path, 'r') as f: + text_annotations_by_frame = [tuple(a) for a in json.load(f)] + return text_annotations_by_frame + + +class A2dSentencesTransforms: + + def __init__(self, subset_type, horizontal_flip_augmentations, + resize_and_crop_augmentations, train_short_size, + train_max_size, eval_short_size, eval_max_size, **kwargs): + self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations + normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + scales = [ + train_short_size + ] # no more scales for now due to GPU memory constraints. might be changed later + transforms = [] + if resize_and_crop_augmentations: + if subset_type == 'train': + transforms.append( + T.RandomResize(scales, max_size=train_max_size)) + elif subset_type == 'test': + transforms.append( + T.RandomResize([eval_short_size], max_size=eval_max_size)), + transforms.extend([T.ToTensor(), normalize]) + self.size_transforms = T.Compose(transforms) + + def __call__(self, source_frames, targets, text_query): + if self.h_flip_augmentation and torch.rand(1) > 0.5: + source_frames = [F.hflip(f) for f in source_frames] + targets[len(targets) // 2]['masks'] = F.hflip( + targets[len(targets) // 2]['masks']) + # Note - is it possible for both 'right' and 'left' to appear together in the same query. hence this fix: + text_query = text_query.replace('left', '@').replace( + 'right', 'left').replace('@', 'right') + source_frames, targets = list( + zip(*[ + self.size_transforms(f, t) + for f, t in zip(source_frames, targets) + ])) + source_frames = torch.stack(source_frames) # [T, 3, H, W] + return source_frames, targets, text_query + + +class Collator: + + def __call__(self, batch): + samples, targets, text_queries = list(zip(*batch)) + samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W] + # convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch + targets = list(zip(*targets)) + batch_dict = { + 'samples': samples, + 'targets': targets, + 'text_queries': text_queries + } + return batch_dict + + +def get_text_annotations_gt(root_path, subset): + # without 'header == None' pandas will ignore the first sample... 
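A minimal sketch, not part of the diff itself, showing how the Collator defined above is meant to be wired into a standard PyTorch DataLoader; `dataset` stands for an already constructed ReferringVideoObjectSegmentationDataset, and the batch size and worker count are illustrative.

from torch.utils.data import DataLoader

def build_rvos_loader(dataset, batch_size=2, num_workers=2):
    # the custom collator packs the frames of a batch into a nested tensor of
    # shape [T, B, C, H, W] and keeps targets/text queries grouped per time step,
    # so it has to be passed explicitly as collate_fn
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=dataset.collator)

# batch = next(iter(build_rvos_loader(dataset)))
# batch['samples'], batch['targets'], batch['text_queries']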
+ a2d_data_info = pandas.read_csv( + osp.join(root_path, 'Release/videoset.csv'), header=None) + # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' + a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset'] + with open(osp.join(root_path, 'text_annotations/missed_videos.txt'), + 'r') as f: + unused_videos = f.read().splitlines() + subsets = {'train': 0, 'test': 1} + # filter unused videos and videos which do not belong to our train/test subset: + used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos) + & (a2d_data_info.subset == subsets[subset])] + used_videos_ids = list(used_videos['vid']) + text_annotations = pandas.read_csv( + osp.join(root_path, 'text_annotations/annotation.txt')) + # filter the text annotations based on the used videos: + used_text_annotations = text_annotations[text_annotations.video_id.isin( + used_videos_ids)] + # convert data-frame to list of tuples: + used_text_annotations = list(used_text_annotations.to_records(index=False)) + return used_text_annotations + + +def create_a2d_sentences_ground_truth_test_annotations(dataset_path, + subset_type, + mask_annotations_dir, + output_path): + text_annotations = get_text_annotations_gt(dataset_path, subset_type) + + # Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly + # expected by pycocotools as it is the convention of the original coco dataset annotations. + + categories_dict = [{ + 'id': 1, + 'name': 'dummy_class' + }] # dummy class, as categories are not used/predicted in RVOS + + images_dict = [] + annotations_dict = [] + images_set = set() + instance_id_counter = 1 + for annot in tqdm(text_annotations): + video_id, instance_id, text_query = annot + annot_paths = sorted( + glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) + for p in annot_paths: + f = h5py.File(p) + instances = list(f['instance']) + try: + instance_idx = instances.index(int(instance_id)) + # in case this instance does not appear in this frame it has no ground-truth mask, and thus this + # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. 
check out: + # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 + except ValueError: + continue # instance_id does not appear in current frame + mask = f['reMask'][instance_idx] if len( + instances) > 1 else np.array(f['reMask']) + mask = mask.transpose() + + frame_idx = int(p.split('/')[-1].split('.')[0]) + image_id = get_image_id(video_id, frame_idx, instance_id) + assert image_id not in images_set, f'error: image id: {image_id} appeared twice' + images_set.add(image_id) + images_dict.append({ + 'id': image_id, + 'height': mask.shape[0], + 'width': mask.shape[1] + }) + + mask_rle = encode(mask) + mask_rle['counts'] = mask_rle['counts'].decode('ascii') + mask_area = float(area(mask_rle)) + bbox = f['reBBox'][:, instance_idx] if len( + instances) > 1 else np.array( + f['reBBox']).squeeze() # x1y1x2y2 form + bbox_xywh = [ + bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] + ] + instance_annot = { + 'id': instance_id_counter, + 'image_id': image_id, + 'category_id': + 1, # dummy class, as categories are not used/predicted in ref-vos + 'segmentation': mask_rle, + 'area': mask_area, + 'bbox': bbox_xywh, + 'iscrowd': 0, + } + annotations_dict.append(instance_annot) + instance_id_counter += 1 + dataset_dict = { + 'categories': categories_dict, + 'images': images_dict, + 'annotations': annotations_dict + } + with open(output_path, 'w') as f: + json.dump(dataset_dict, f) diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py new file mode 100644 index 00000000..a5067b1b --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py @@ -0,0 +1,294 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr + +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from modelscope.models.cv.referring_video_object_segmentation.utils import \ + interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target['size'] = torch.tensor([h, w]) + + fields = ['labels', 'area', 'iscrowd'] + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? 
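A short, self-contained illustration, not part of the diff, of the pycocotools RLE round trip that the ground-truth builder above relies on; the mask size and the rectangle are made up.

import numpy as np
from pycocotools import mask as mask_utils

# pycocotools expects a Fortran-contiguous uint8 mask; np.asfortranarray makes that explicit here
binary_mask = np.zeros((240, 320), dtype=np.uint8)
binary_mask[60:120, 80:200] = 1
rle = mask_utils.encode(np.asfortranarray(binary_mask))
print(mask_utils.area(rle))     # 7200, pixel area of the rectangle
print(mask_utils.toBbox(rle))   # [80. 60. 120. 60.], i.e. x, y, w, h
rle['counts'] = rle['counts'].decode('ascii')   # JSON-serializable, as done above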
+ target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target['boxes'] = boxes + + if 'masks' in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int( + round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) + for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + h, w = size + target['size'] = torch.tensor([h, w]) + + if 'masks' in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? 
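A worked example, added for illustration only, of the aspect-ratio-preserving size computation performed by resize() above; the frame dimensions are made up and the arithmetic simply mirrors get_size_with_aspect_ratio.

# a 640x480 (w, h) frame, target short side 360, long side capped at 640
w, h = 640, 480
size, max_size = 360, 640
# clamp the requested short side so the scaled long side never exceeds max_size
if max_size is not None and (max(w, h) / min(w, h)) * size > max_size:
    size = int(round(max_size * min(w, h) / max(w, h)))
# scale the short side to `size`, keeping the aspect ratio
oh, ow = (size, int(size * w / h)) if w >= h else (int(size * h / w), size)
print((oh, ow))   # (360, 480), handed to torchvision's F.resize as (h, w)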
+ target['size'] = torch.tensor(padded_image.size[::-1]) + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, + (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if 'boxes' in target: + boxes = target['boxes'] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target['boxes'] = boxes + return image, target + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + 
format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py index c7aa7682..7a46b325 100644 --- a/modelscope/msdatasets/utils/dataset_utils.py +++ b/modelscope/msdatasets/utils/dataset_utils.py @@ -82,7 +82,7 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, dataset_name: str, namespace: str, version: str) -> list: """ - List all of objects for specific dataset. + List all objects for specific dataset. Args: hub_api (class HubApi): HubApi instance. diff --git a/modelscope/msdatasets/utils/delete_utils.py b/modelscope/msdatasets/utils/delete_utils.py new file mode 100644 index 00000000..a5a6f53f --- /dev/null +++ b/modelscope/msdatasets/utils/delete_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.hub.api import HubApi + + +class DatasetDeleteManager(object): + + def __init__(self, dataset_name: str, namespace: str, version: str): + self.api = HubApi() + self.dataset_name = dataset_name + self.namespace = namespace + self.version = version + + def delete(self, object_name: str) -> str: + + # single object + if not object_name.endswith('/'): + resp_msg = self.api.delete_oss_dataset_object( + object_name=object_name, + dataset_name=self.dataset_name, + namespace=self.namespace, + revision=self.version) + else: + # multiple objects + object_name = object_name.strip('/') + resp_msg = self.api.delete_oss_dataset_dir( + object_name=object_name, + dataset_name=self.dataset_name, + namespace=self.namespace, + revision=self.version) + + return resp_msg diff --git a/modelscope/msdatasets/utils/download_utils.py b/modelscope/msdatasets/utils/download_utils.py index b1c7a5ab..ebe9b8f5 100644 --- a/modelscope/msdatasets/utils/download_utils.py +++ b/modelscope/msdatasets/utils/download_utils.py @@ -27,7 +27,11 @@ class DatasetDownloadManager(DownloadManager): oss_config = api.get_dataset_access_config(self._dataset_name, self._namespace, self._version) - self.oss_utilities = OssUtilities(oss_config) + self.oss_utilities = OssUtilities( + oss_config=oss_config, + dataset_name=self._dataset_name, + namespace=self._namespace, + revision=self._version) def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str: diff --git a/modelscope/msdatasets/utils/oss_utils.py b/modelscope/msdatasets/utils/oss_utils.py index d7d61e89..e27ff8c4 100644 --- a/modelscope/msdatasets/utils/oss_utils.py +++ b/modelscope/msdatasets/utils/oss_utils.py @@ -6,19 +6,28 @@ import os import oss2 from datasets.utils.file_utils import hash_url_to_filename +from modelscope.hub.api import HubApi +from modelscope.utils.constant import UploadMode +from modelscope.utils.logger import get_logger + +logger = get_logger() + +ACCESS_ID = 'AccessId' +ACCESS_SECRET = 'AccessSecret' +SECURITY_TOKEN = 'SecurityToken' +BUCKET = 'Bucket' +BACK_DIR = 'BackupDir' +DIR = 'Dir' + class OssUtilities: - def __init__(self, oss_config): - self.key = oss_config['AccessId'] - self.secret = oss_config['AccessSecret'] - self.token = oss_config['SecurityToken'] - self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" - self.bucket_name = oss_config['Bucket'] - auth = oss2.StsAuth(self.key, self.secret, self.token) - self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) - self.oss_dir = oss_config['Dir'] - self.oss_backup_dir = oss_config['BackupDir'] + def __init__(self, oss_config, dataset_name, namespace, revision): + 
self._do_init(oss_config=oss_config) + + self.dataset_name = dataset_name + self.namespace = namespace + self.revision = revision self.upload_resumable_tmp_store = '/tmp/modelscope/tmp_dataset' self.upload_multipart_threshold = 50 * 1024 * 1024 @@ -26,6 +35,28 @@ class OssUtilities: self.upload_num_threads = 4 self.upload_max_retries = 3 + self.api = HubApi() + + def _do_init(self, oss_config): + self.key = oss_config[ACCESS_ID] + self.secret = oss_config[ACCESS_SECRET] + self.token = oss_config[SECURITY_TOKEN] + self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" + self.bucket_name = oss_config[BUCKET] + auth = oss2.StsAuth(self.key, self.secret, self.token) + self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) + self.oss_dir = oss_config[DIR] + self.oss_backup_dir = oss_config[BACK_DIR] + + def _reload_sts(self): + cookies = self.api.check_local_cookies(use_cookies=True) + oss_config_refresh = self.api.get_dataset_access_config_session( + cookies=cookies, + dataset_name=self.dataset_name, + namespace=self.namespace, + revision=self.revision) + self._do_init(oss_config_refresh) + @staticmethod def _percentage(consumed_bytes, total_bytes): if total_bytes: @@ -51,7 +82,8 @@ class OssUtilities: return local_path def upload(self, oss_object_name: str, local_file_path: str, - indicate_individual_progress: bool) -> str: + indicate_individual_progress: bool, + upload_mode: UploadMode) -> str: retry_count = 0 object_key = os.path.join(self.oss_dir, oss_object_name) resumable_store = oss2.ResumableStore( @@ -64,6 +96,13 @@ class OssUtilities: while True: try: retry_count += 1 + exist = self.bucket.object_exists(object_key) + if upload_mode == UploadMode.APPEND and exist: + logger.info( + f'Skip {oss_object_name} in case of {upload_mode.value} mode.' 
+ ) + break + oss2.resumable_upload( self.bucket, object_key, @@ -74,7 +113,9 @@ class OssUtilities: progress_callback=progress_callback, num_threads=self.upload_num_threads) break - except Exception: + except Exception as e: + if e.__getattribute__('status') == 403: + self._reload_sts() if retry_count >= self.upload_max_retries: raise diff --git a/modelscope/msdatasets/utils/upload_utils.py b/modelscope/msdatasets/utils/upload_utils.py index 2b4422b2..bbdcd9e9 100644 --- a/modelscope/msdatasets/utils/upload_utils.py +++ b/modelscope/msdatasets/utils/upload_utils.py @@ -5,6 +5,7 @@ from multiprocessing.dummy import Pool as ThreadPool from tqdm import tqdm +from modelscope.utils.constant import UploadMode from .oss_utils import OssUtilities @@ -13,38 +14,45 @@ class DatasetUploadManager(object): def __init__(self, dataset_name: str, namespace: str, version: str): from modelscope.hub.api import HubApi _hub_api = HubApi() - _cookies = _hub_api.check_cookies_upload_data(use_cookies=True) + _cookies = _hub_api.check_local_cookies(use_cookies=True) _oss_config = _hub_api.get_dataset_access_config_session( cookies=_cookies, dataset_name=dataset_name, namespace=namespace, revision=version) - self.oss_utilities = OssUtilities(_oss_config) + self.oss_utilities = OssUtilities( + oss_config=_oss_config, + dataset_name=dataset_name, + namespace=namespace, + revision=version) - def upload(self, object_name: str, local_file_path: str) -> str: + def upload(self, object_name: str, local_file_path: str, + upload_mode: UploadMode) -> str: object_key = self.oss_utilities.upload( oss_object_name=object_name, local_file_path=local_file_path, - indicate_individual_progress=True) + indicate_individual_progress=True, + upload_mode=upload_mode) return object_key def upload_dir(self, object_dir_name: str, local_dir_path: str, num_processes: int, chunksize: int, - filter_hidden_files: bool) -> int: + filter_hidden_files: bool, upload_mode: UploadMode) -> int: def run_upload(args): self.oss_utilities.upload( oss_object_name=args[0], local_file_path=args[1], - indicate_individual_progress=False) + indicate_individual_progress=False, + upload_mode=upload_mode) files_list = [] for root, dirs, files in os.walk(local_dir_path): for file_name in files: if filter_hidden_files and file_name.startswith('.'): continue - # Concatenate directory name and relative path into a oss object key. e.g., train/001/1_1230.png + # Concatenate directory name and relative path into oss object key. e.g., train/001/1_1230.png object_name = os.path.join( object_dir_name, root.replace(local_dir_path, '', 1).strip('/'), file_name) diff --git a/modelscope/outputs/nlp/model_outputs.py b/modelscope/outputs/nlp/model_outputs.py index dcb37145..46267007 100644 --- a/modelscope/outputs/nlp/model_outputs.py +++ b/modelscope/outputs/nlp/model_outputs.py @@ -541,3 +541,50 @@ class Seq2SeqLMOutput(ModelOutputBase): encoder_last_hidden_state: Optional[Tensor] = None encoder_hidden_states: Optional[Tuple[Tensor]] = None encoder_attentions: Optional[Tuple[Tensor]] = None + + +@dataclass +class TextGenerationModelOutput(ModelOutputBase): + """The output class for text generation models. + + Args: + logits (`Tensor`): The logits output of the model. loss (`Tensor`, + *optional*) The loss of the model, available when training. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at the + output of each layer plus the optional initial embedding outputs. 
+ """ + + logits: Tensor = None + loss: Tensor = None + + +@dataclass +class TokenGeneratorOutput(ModelOutputBase): + """ + The output class for generate method of text generation models. + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` + is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` + is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, + sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` + is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. + """ + + sequences: Tensor = None + scores: Optional[Tuple[Tensor]] = None + attentions: Optional[Tuple[Tuple[Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[Tensor]]] = None diff --git a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py index d264b386..cfbf2607 100644 --- a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py +++ b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py @@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline): * text_border_height_per_query, 0, 0)) W, H = vid_frame.size draw = ImageDraw.Draw(vid_frame) - font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) + + if self.model.cfg.pipeline.output_font: + font = ImageFont.truetype( + font=self.model.cfg.pipeline.output_font, + size=self.model.cfg.pipeline.output_font_size) + else: + font = ImageFont.load_default() for i, (text_query, color) in enumerate( zip(self.text_queries, colors), start=1): w, h = draw.textsize(text_query, font=font) diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index 2d5b664f..fdde5f25 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -104,6 +104,10 @@ class TextGenerationPipeline(Pipeline): tokenizer = self.preprocessor.tokenizer return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) + def sentence_piece(self, inputs) -> str: + tokenizer = self.preprocessor.tokenizer + return tokenizer.decode(inputs.tolist()) + def roberta(self, inputs) -> str: tokenizer = self.preprocessor.tokenizer decoded = tokenizer.decode(inputs.tolist()) @@ -121,7 +125,7 @@ class TextGenerationPipeline(Pipeline): Dict[str, str]: the prediction results """ 
inputs = inputs['sequences'] - if isinstance(inputs, list): + if isinstance(inputs, list) or len(inputs.shape) > 1: inputs = inputs[0] decoded = getattr(self, self.postprocessor)(inputs) text = self._remove_space_between_chinese_chars(decoded) diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py index c36f0dfc..75bc538d 100644 --- a/modelscope/pipelines/nlp/token_classification_pipeline.py +++ b/modelscope/pipelines/nlp/token_classification_pipeline.py @@ -17,6 +17,8 @@ from modelscope.utils.tensor_utils import (torch_nested_detach, __all__ = ['TokenClassificationPipeline'] +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.token_classification) @PIPELINES.register_module( Tasks.token_classification, module_name=Pipelines.part_of_speech) @PIPELINES.register_module( @@ -41,7 +43,7 @@ class TokenClassificationPipeline(Pipeline): str) else model if preprocessor is None: - preprocessor = Model.from_pretrained( + preprocessor = Preprocessor.from_pretrained( model.model_dir, sequence_length=kwargs.pop('sequence_length', 128)) model.eval() diff --git a/modelscope/preprocessors/base.py b/modelscope/preprocessors/base.py index db14ba47..be62ebb4 100644 --- a/modelscope/preprocessors/base.py +++ b/modelscope/preprocessors/base.py @@ -147,8 +147,50 @@ class Preprocessor(ABC): cfg_dict: Config = None, preprocessor_mode=ModeKeys.INFERENCE, **kwargs): - """ Instantiate a model from local directory or remote model repo. Note + """Instantiate a preprocessor from local directory or remote model repo. Note that when loading from remote, the model revision can be specified. + + Args: + model_name_or_path(str): A model dir or a model id used to load the preprocessor out. + revision(str, `optional`): The revision used when the model_name_or_path is + a model id of the remote hub. default `master`. + cfg_dict(Config, `optional`): An optional config. If provided, it will replace + the config read out of the `model_name_or_path` + preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor, can be `train`, `eval`, + or `inference`. Default value `inference`. + The preprocessor field in the config may contain two sub preprocessors: + >>> { + >>> "train": { + >>> "type": "some-train-preprocessor" + >>> }, + >>> "val": { + >>> "type": "some-eval-preprocessor" + >>> } + >>> } + In this scenario, the `train` preprocessor will be loaded in the `train` mode, the `val` preprocessor + will be loaded in the `eval` or `inference` mode. The `mode` field in the preprocessor class + will be assigned in all the modes. + Or just one: + >>> { + >>> "type": "some-train-preprocessor" + >>> } + In this scenario, the sole preprocessor will be loaded in all the modes, + and the `mode` field in the preprocessor class will be assigned. + + **kwargs: + task(str, `optional`): The `Tasks` enumeration value to replace the task value + read out of config in the `model_name_or_path`. + This is useful when the preprocessor does not have a `type` field and the task to be used is not + equal to the task of which the model is saved. + Other kwargs will be directly fed into the preprocessor, to replace the default configs. + + Returns: + The preprocessor instance. 
+ + Examples: + >>> from modelscope.preprocessors import Preprocessor + >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base') + """ if not os.path.exists(model_name_or_path): model_dir = snapshot_download( diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 256c5243..557b469a 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -157,7 +157,7 @@ class MPlugPreprocessor(Preprocessor): def image_open(self, path: str) -> Tuple[Image.Image, int]: if path not in self._image_map: index = len(self._image_map) - self._image_map[path] = (Image.open(path), index) + self._image_map[path] = (load_image(path), index) return self._image_map[path] def __call__( diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index d914489c..37fdcc12 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -9,7 +9,8 @@ if TYPE_CHECKING: from .builder import build_trainer from .cv import (ImageInstanceSegmentationTrainer, ImagePortraitEnhancementTrainer, - MovieSceneSegmentationTrainer, ImageInpaintingTrainer) + MovieSceneSegmentationTrainer, ImageInpaintingTrainer, + ReferringVideoObjectSegmentationTrainer) from .multi_modal import CLIPTrainer from .nlp import SequenceClassificationTrainer, TextRankingTrainer from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index d09fd75c..32c38de2 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer from .image_inpainting_trainer import ImageInpaintingTrainer + from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer else: _import_structure = { @@ -17,7 +18,9 @@ else: 'image_portrait_enhancement_trainer': ['ImagePortraitEnhancementTrainer'], 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], - 'image_inpainting_trainer': ['ImageInpaintingTrainer'] + 'image_inpainting_trainer': ['ImageInpaintingTrainer'], + 'referring_video_object_segmentation_trainer': + ['ReferringVideoObjectSegmentationTrainer'] } import sys diff --git a/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py b/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py new file mode 100644 index 00000000..c15df3a5 --- /dev/null +++ b/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
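A compact, hypothetical helper restating the train/val preprocessor selection described in the Preprocessor.from_pretrained docstring above; it is a sketch of the documented behaviour, not the library's actual resolution code, and the sub-config names follow the docstring.

from modelscope.utils.constant import ModeKeys

preprocessor_cfg = {
    'train': {'type': 'some-train-preprocessor'},
    'val': {'type': 'some-eval-preprocessor'},
}

def pick_sub_config(cfg, mode):
    # 'train' entry in train mode, 'val' entry in eval/inference mode;
    # a config without the two sub keys is shared by every mode
    if 'train' in cfg or 'val' in cfg:
        return cfg['train'] if mode == ModeKeys.TRAIN else cfg['val']
    return cfg

print(pick_sub_config(preprocessor_cfg, ModeKeys.TRAIN))      # {'type': 'some-train-preprocessor'}
print(pick_sub_config(preprocessor_cfg, ModeKeys.INFERENCE))  # {'type': 'some-eval-preprocessor'}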
+import os + +import torch + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import ModeKeys + + +@TRAINERS.register_module( + module_name=Trainers.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model.set_postprocessor(self.cfg.dataset.name) + self.train_data_collator = self.train_dataset.collator + self.eval_data_collator = self.eval_dataset.collator + + device_name = kwargs.get('device', 'gpu') + self.model.set_device(self.device, device_name) + + def train(self, *args, **kwargs): + self.model.criterion.train() + super().train(*args, **kwargs) + + def evaluate(self, checkpoint_path=None): + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + CheckpointHook.load_checkpoint(checkpoint_path, self) + self.model.eval() + self._mode = ModeKeys.EVAL + if self.eval_dataset is None: + self.eval_dataloader = self.get_eval_data_loader() + else: + self.eval_dataloader = self._build_dataloader_with_dataset( + self.eval_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.eval_data_collator, + **self.cfg.evaluation.get('dataloader', {})) + self.data_loader = self.eval_dataloader + + from modelscope.metrics import build_metric + ann_file = self.eval_dataset.ann_file + metric_classes = [] + for metric in self.metrics: + metric.update({'ann_file': ann_file}) + metric_classes.append(build_metric(metric)) + + for m in metric_classes: + m.trainer = self + + metric_values = self.evaluation_loop(self.eval_dataloader, + metric_classes) + + self._metric_values = metric_values + return metric_values + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py index 9b86d5b5..89aa39ba 100644 --- a/modelscope/trainers/hooks/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -101,8 +101,9 @@ class CheckpointHook(Hook): model = trainer.model.module else: model = trainer.model - meta = load_checkpoint(filename, model, trainer.optimizer, - trainer.lr_scheduler) + meta = load_checkpoint(filename, model, + getattr(trainer, 'optimizer', None), + getattr(trainer, 'lr_scheduler', None)) trainer._epoch = meta.get('epoch', trainer._epoch) trainer._iter = meta.get('iter', trainer._iter) trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) @@ -111,7 +112,7 @@ class CheckpointHook(Hook): # hook: Hook key = f'{hook.__class__}-{i}' if key in meta and hasattr(hook, 'load_state_dict'): - hook.load_state_dict(meta[key]) + hook.load_state_dict(meta.get(key, {})) else: trainer.logger.warn( f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' @@ -123,7 +124,7 @@ class CheckpointHook(Hook): f'The modelscope version of loaded checkpoint does not match the runtime version. 
' f'The saved version: {version}, runtime version: {__version__}' ) - trainer.logger.warn( + trainer.logger.info( f'Checkpoint {filename} saving time: {meta.get("time")}') return meta @@ -171,12 +172,17 @@ class CheckpointHook(Hook): else: model = trainer.model + config = trainer.cfg.to_dict() + # override pipeline by tasks name after finetune done, + # avoid case like fill mask pipeline with a text cls task + config['pipeline'] = {'type': config['task']} + if hasattr(model, 'save_pretrained'): model.save_pretrained( output_dir, ModelFile.TORCH_MODEL_BIN_FILE, save_function=save_checkpoint, - config=trainer.cfg.to_dict(), + config=config, with_meta=False) def after_train_iter(self, trainer): diff --git a/modelscope/trainers/multi_modal/__init__.py b/modelscope/trainers/multi_modal/__init__.py index 89b7e1bc..448f23a3 100644 --- a/modelscope/trainers/multi_modal/__init__.py +++ b/modelscope/trainers/multi_modal/__init__.py @@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .clip import CLIPTrainer + from .team import TEAMImgClsTrainer else: - _import_structure = {'clip': ['CLIPTrainer']} + _import_structure = { + 'clip': ['CLIPTrainer'], + 'team': ['TEAMImgClsTrainer'] + } import sys diff --git a/modelscope/trainers/multi_modal/team/__init__.py b/modelscope/trainers/multi_modal/team/__init__.py new file mode 100644 index 00000000..b48fcc7e --- /dev/null +++ b/modelscope/trainers/multi_modal/team/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .team_trainer import TEAMImgClsTrainer diff --git a/modelscope/trainers/multi_modal/team/team_trainer.py b/modelscope/trainers/multi_modal/team/team_trainer.py new file mode 100644 index 00000000..7c557416 --- /dev/null +++ b/modelscope/trainers/multi_modal/team/team_trainer.py @@ -0,0 +1,144 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
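A minimal illustration of the configuration override applied above before save_pretrained: the pipeline type is forced to the task name so that a model finetuned for, say, text classification is not loaded back with the pretrained checkpoint's fill-mask pipeline. The dict values here are made up.

config = {
    'task': 'text-classification',       # illustrative task name
    'pipeline': {'type': 'fill-mask'},   # pipeline inherited from the pretrained checkpoint
}
config['pipeline'] = {'type': config['task']}
print(config['pipeline'])   # {'type': 'text-classification'}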
+ +import os +from collections import OrderedDict +from typing import Callable, Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from sklearn.metrics import confusion_matrix +from torch.optim import AdamW +from torch.utils.data import DataLoader, Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.multi_modal.team.team_trainer_utils import ( + get_optimizer, train_mapping, val_mapping) +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModeKeys +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@TRAINERS.register_module(module_name=Trainers.image_classification_team) +class TEAMImgClsTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, model: str, device_id: int, + data_collator: Callable, train_dataset: Dataset, + val_dataset: Dataset, *args, **kwargs): + super().__init__(cfg_file) + + self.cfg = Config.from_file(cfg_file) + team_model = Model.from_pretrained(model) + image_model = team_model.model.image_model.vision_transformer + classification_model = nn.Sequential( + OrderedDict([('encoder', image_model), + ('classifier', + nn.Linear(768, self.cfg.dataset.class_num))])) + self.model = classification_model + + for pname, param in self.model.named_parameters(): + if 'encoder' in pname: + param.requires_grad = False + + self.device_id = device_id + self.total_epoch = self.cfg.train.epoch + self.train_batch_size = self.cfg.train.batch_size + self.val_batch_size = self.cfg.evaluation.batch_size + self.ckpt_dir = self.cfg.train.ckpt_dir + + self.collate_fn = data_collator + self.train_dataset = train_dataset + self.val_dataset = val_dataset + + self.criterion = nn.CrossEntropyLoss().to(self.device_id) + + def train(self, *args, **kwargs): + self.model.train() + self.model.to(self.device_id) + + optimizer = get_optimizer(self.model) + + for epoch in range(self.total_epoch): + train_params = { + 'pin_memory': True, + 'collate_fn': self.collate_fn, + 'batch_size': self.train_batch_size, + 'shuffle': True, + 'drop_last': True, + 'num_workers': 8 + } + + train_loader = DataLoader(self.train_dataset, **train_params) + + for batch_idx, data in enumerate(train_loader): + img_tensor, label_tensor = data['pixel_values'], data['labels'] + img_tensor = img_tensor.to(self.device_id, non_blocking=True) + label_tensor = label_tensor.to( + self.device_id, non_blocking=True) + + pred_logits = self.model(img_tensor) + loss = self.criterion(pred_logits, label_tensor) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch_idx % 10 == 0: + logger.info( + 'epoch: {}, train batch {}/{}, loss={:.5f}'.format( + epoch, batch_idx, len(train_loader), loss.item())) + + os.makedirs(self.ckpt_dir, exist_ok=True) + torch.save(self.model.state_dict(), + '{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) + self.evaluate() + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + if checkpoint_path is not None: + checkpoint_params = torch.load(checkpoint_path, 'cpu') + self.model.load_state_dict(checkpoint_params) + self.model.eval() + self.model.to(self.device_id) + + val_params = { + 'collate_fn': self.collate_fn, + 'batch_size': self.val_batch_size, + 
'shuffle': False, + 'drop_last': False, + 'num_workers': 8 + } + val_loader = DataLoader(self.val_dataset, **val_params) + + tp_cnt, processed_cnt = 0, 0 + all_pred_labels, all_gt_labels = [], [] + with torch.no_grad(): + for batch_idx, data in enumerate(val_loader): + img_tensor, label_tensor = data['pixel_values'], data['labels'] + img_tensor = img_tensor.to(self.device_id, non_blocking=True) + label_tensor = label_tensor.to( + self.device_id, non_blocking=True) + + pred_logits = self.model(img_tensor) + pred_labels = torch.max(pred_logits, dim=1)[1] + tp_cnt += torch.sum(pred_labels == label_tensor).item() + processed_cnt += img_tensor.shape[0] + logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt)) + + all_pred_labels.extend(pred_labels.tolist()) + all_gt_labels.extend(label_tensor.tolist()) + conf_mat = confusion_matrix(all_gt_labels, all_pred_labels) + acc_mean_per_class = np.mean(conf_mat.diagonal() + / conf_mat.sum(axis=1)) + logger.info( + 'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class)) diff --git a/modelscope/trainers/multi_modal/team/team_trainer_utils.py b/modelscope/trainers/multi_modal/team/team_trainer_utils.py new file mode 100644 index 00000000..ff1a4fd6 --- /dev/null +++ b/modelscope/trainers/multi_modal/team/team_trainer_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torchvision.transforms as transforms +from PIL import Image +from torch.optim import AdamW + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +train_transforms = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), +]) +val_transforms = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), +]) + + +def train_mapping(examples): + examples['pixel_values'] = [ + train_transforms(Image.open(image).convert('RGB')) + for image in examples['image:FILE'] + ] + examples['labels'] = [label for label in examples['label:LABEL']] + return examples + + +def val_mapping(examples): + examples['pixel_values'] = [ + val_transforms(Image.open(image).convert('RGB')) + for image in examples['image:FILE'] + ] + examples['labels'] = [label for label in examples['label:LABEL']] + return examples + + +def collate_fn(examples): + images = [] + labels = [] + for example in examples: + images.append((example['pixel_values'])) + labels.append(example['labels']) + + pixel_values = torch.stack(images) + labels = torch.tensor(labels) + return {'pixel_values': pixel_values, 'labels': labels} + + +def get_params_groups(ddp_model, lr): + large_lr_params = [] + small_lr_params = [] + for name, param in ddp_model.named_parameters(): + if not param.requires_grad: + continue + + if 'encoder' in name: + small_lr_params.append(param) + elif 'classifier' in name: + large_lr_params.append(param) + else: + logger.info('skip param: {}'.format(name)) + + params_groups = [{ + 'params': small_lr_params, + 'lr': lr / 10.0 + }, { + 'params': large_lr_params, + 'lr': lr + }] + return params_groups + + +def get_optimizer(ddp_model): + lr_init = 1e-3 + betas = [0.9, 0.999] + weight_decay = 0.02 + params_groups = get_params_groups(ddp_model, lr=lr_init) + return AdamW( + params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) diff --git 
a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py index a19e7c7b..a92a3706 100644 --- a/modelscope/trainers/nlp_trainer.py +++ b/modelscope/trainers/nlp_trainer.py @@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer): break for metric_name in self.metrics: - metric_values[metric_name] = np.average( - [m[metric_name] for m in metric_values.values()]) + all_metrics = [m[metric_name] for m in metric_values.values()] + for key in all_metrics[0].keys(): + metric_values[key] = np.average( + [metric[key] for metric in all_metrics]) return metric_values diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index f660a55a..e1fd7522 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer): return dataset def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): - return build_optimizer(self.model, cfg=cfg, default_args=default_args) + try: + return build_optimizer( + self.model, cfg=cfg, default_args=default_args) + except KeyError as e: + self.logger.error( + f'Build optimizer error, the optimizer {cfg} is native torch optimizer, ' + f'please check if your torch with version: {torch.__version__} matches the config.' + ) + raise e def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): - return build_lr_scheduler(cfg=cfg, default_args=default_args) + try: + return build_lr_scheduler(cfg=cfg, default_args=default_args) + except KeyError as e: + self.logger.error( + f'Build lr_scheduler error, the lr_scheduler {cfg} is native torch lr_scheduler, ' + f'please check if your torch with version: {torch.__version__} matches the config.' + ) + raise e def create_optimizer_and_scheduler(self): """ Create optimizer and lr scheduler diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py index d6187b5f..6e4e7a19 100644 --- a/modelscope/trainers/utils/inference.py +++ b/modelscope/trainers/utils/inference.py @@ -62,7 +62,10 @@ def single_gpu_test(trainer, if 'nsentences' in data: batch_size = data['nsentences'] else: - batch_size = len(next(iter(data.values()))) + try: + batch_size = len(next(iter(data.values()))) + except Exception: + batch_size = data_loader.batch_size else: batch_size = len(data) for _ in range(batch_size): diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py index 2a7520f2..5acaa411 100644 --- a/modelscope/utils/checkpoint.py +++ b/modelscope/utils/checkpoint.py @@ -134,9 +134,7 @@ def load_checkpoint(filename, state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ 'state_dict'] model.load_state_dict(state_dict) - - if 'meta' in checkpoint: - return checkpoint.get('meta', {}) + return checkpoint.get('meta', {}) def save_pretrained(model, diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6394ad8a..2729b75a 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -238,6 +238,15 @@ class DownloadMode(enum.Enum): FORCE_REDOWNLOAD = 'force_redownload' +class UploadMode(enum.Enum): + """ How to upload object to remote. + """ + # Upload all objects from local, existing remote objects may be overwritten. (Default) + OVERWRITE = 'overwrite' + # Upload local objects in append mode, skipping all existing remote objects. 
+ APPEND = 'append' + + class DatasetFormations(enum.Enum): """ How a dataset is organized and interpreted """ diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index 828b97f8..5b6e957d 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -87,21 +87,23 @@ class HubOperationTest(unittest.TestCase): assert mdtime1 == mdtime2 def test_download_public_without_login(self): - self.prepare_case() - rmtree(ModelScopeConfig.path_credential) - snapshot_path = snapshot_download( - model_id=self.model_id, revision=self.revision) - downloaded_file_path = os.path.join(snapshot_path, - download_model_file_name) - assert os.path.exists(downloaded_file_path) - temporary_dir = tempfile.mkdtemp() - downloaded_file = model_file_download( - model_id=self.model_id, - file_path=download_model_file_name, - revision=self.revision, - cache_dir=temporary_dir) - assert os.path.exists(downloaded_file) - self.api.login(TEST_ACCESS_TOKEN1) + try: + self.prepare_case() + rmtree(ModelScopeConfig.path_credential) + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) + downloaded_file_path = os.path.join(snapshot_path, + download_model_file_name) + assert os.path.exists(downloaded_file_path) + temporary_dir = tempfile.mkdtemp() + downloaded_file = model_file_download( + model_id=self.model_id, + file_path=download_model_file_name, + revision=self.revision, + cache_dir=temporary_dir) + assert os.path.exists(downloaded_file) + finally: + self.api.login(TEST_ACCESS_TOKEN1) def test_snapshot_delete_download_cache_file(self): self.prepare_case() diff --git a/tests/metrics/test_text_classification_metrics.py b/tests/metrics/test_text_classification_metrics.py new file mode 100644 index 00000000..d0a4cee1 --- /dev/null +++ b/tests/metrics/test_text_classification_metrics.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.metrics.sequence_classification_metric import \ + SequenceClassificationMetric +from modelscope.utils.test_utils import test_level + + +class TestTextClsMetrics(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_value(self): + metric = SequenceClassificationMetric() + outputs = { + 'logits': + np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], + [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], + [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]) + } + inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])} + metric.add(outputs, inputs) + ret = metric.evaluate() + self.assertTrue(np.isclose(ret['f1'], 0.5)) + self.assertTrue(np.isclose(ret['accuracy'], 0.5)) + print(ret) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/test_dataset_delete.py b/tests/msdatasets/test_dataset_delete.py new file mode 100644 index 00000000..8b3c2426 --- /dev/null +++ b/tests/msdatasets/test_dataset_delete.py @@ -0,0 +1,112 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
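A hedged usage sketch of append-mode uploads: the call below uses only the DatasetUploadManager signature introduced earlier in this patch; the dataset name, namespace and file paths are placeholders, and a valid ModelScope login is assumed so the session OSS config can be fetched.

from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
from modelscope.utils.constant import UploadMode

manager = DatasetUploadManager(
    dataset_name='small_coco_for_test',   # placeholder dataset on the hub
    namespace='my_namespace',             # placeholder namespace
    version='master')
manager.upload(
    object_name='train/001/1_1230.png',   # oss object key
    local_file_path='./1_1230.png',       # placeholder local path
    upload_mode=UploadMode.APPEND)        # skip objects that already exist remotely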
+import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.msdatasets import MsDataset +from modelscope.utils import logger as logging +from modelscope.utils.test_utils import test_level + +logger = logging.get_logger(__name__) + +KEY_EXTRACTED = 'extracted' +EXPECTED_MSG = 'success' + + +class DatasetDeleteTest(unittest.TestCase): + + def setUp(self): + self.old_dir = os.getcwd() + self.dataset_name = 'small_coco_for_test' + self.dataset_file_name = self.dataset_name + self.prepared_dataset_name = 'pets_small' + self.token = os.getenv('TEST_UPLOAD_MS_TOKEN') + error_msg = 'The modelscope token can not be empty, please set env variable: TEST_UPLOAD_MS_TOKEN' + self.assertIsNotNone(self.token, msg=error_msg) + from modelscope.hub.api import HubApi + from modelscope.hub.api import ModelScopeConfig + self.api = HubApi() + self.api.login(self.token) + + # get user info + self.namespace, _ = ModelScopeConfig.get_user_info() + + self.temp_dir = tempfile.mkdtemp() + self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name) + if not os.path.exists(self.test_work_dir): + os.makedirs(self.test_work_dir) + + def tearDown(self): + os.chdir(self.old_dir) + shutil.rmtree(self.temp_dir, ignore_errors=True) + logger.info( + f'Temporary directory {self.temp_dir} successfully removed!') + + @staticmethod + def get_raw_downloaded_file_path(extracted_path): + raw_downloaded_file_path = '' + raw_data_dir = os.path.abspath( + os.path.join(extracted_path, '../../..')) + for root, dirs, files in os.walk(raw_data_dir): + if KEY_EXTRACTED in dirs: + for file in files: + curr_file_path = os.path.join(root, file) + if zipfile.is_zipfile(curr_file_path): + raw_downloaded_file_path = curr_file_path + return raw_downloaded_file_path + + def upload_test_file(self): + # Get the prepared data from hub, using default modelscope namespace + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_res = ms_ds_train._hf_ds.config_kwargs + extracted_path = config_res.get('split_config').get('train') + raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path) + + object_name = self.dataset_file_name + '_for_del.zip' + MsDataset.upload( + object_name=object_name, + local_file_path=raw_zipfile_path, + dataset_name=self.dataset_name, + namespace=self.namespace) + + return object_name + + def upload_test_dir(self): + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_train = ms_ds_train._hf_ds.config_kwargs + extracted_path_train = config_train.get('split_config').get('train') + + object_name = 'train_for_del' + MsDataset.upload( + object_name=object_name, + local_file_path=os.path.join(extracted_path_train, + 'Pets/images/train'), + dataset_name=self.dataset_name, + namespace=self.namespace) + + return object_name + '/' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_delete_object(self): + + # upload prepared data + file_name = self.upload_test_file() + dir_name = self.upload_test_dir() + + # delete object + del_file_msg = MsDataset.delete( + object_name=file_name, + dataset_name=self.dataset_name, + namespace=self.namespace) + del_dir_msg = MsDataset.delete( + object_name=dir_name, + dataset_name=self.dataset_name, + namespace=self.namespace) + + assert all([del_file_msg == EXPECTED_MSG, del_dir_msg == EXPECTED_MSG]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 57dcb0c3..6be70468 
100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -243,6 +243,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): def test_run_with_text_to_image_synthesis_with_name(self): model = 'damo/ofa_text-to-image-synthesis_coco_large_en' ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) + ofa_pipe.model.generator.beam_size = 2 example = {'text': 'a bear in the water.'} result = ofa_pipe(example) result[OutputKeys.OUTPUT_IMG].save('result.png') @@ -253,6 +254,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained( 'damo/ofa_text-to-image-synthesis_coco_large_en') ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) + ofa_pipe.model.generator.beam_size = 2 example = {'text': 'a bear in the water.'} result = ofa_pipe(example) result[OutputKeys.OUTPUT_IMG].save('result.png') diff --git a/tests/pipelines/test_referring_video_object_segmentation.py b/tests/pipelines/test_referring_video_object_segmentation.py index 3e81d9c3..4d8206b3 100644 --- a/tests/pipelines/test_referring_video_object_segmentation.py +++ b/tests/pipelines/test_referring_video_object_segmentation.py @@ -14,7 +14,7 @@ class ReferringVideoObjectSegmentationTest(unittest.TestCase, self.task = Tasks.referring_video_object_segmentation self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation' - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skip('skip since the model is set to private for now') def test_referring_video_object_segmentation(self): input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' text_queries = [ @@ -31,7 +31,7 @@ class ReferringVideoObjectSegmentationTest(unittest.TestCase, else: raise ValueError('process error') - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skip('skip since the model is set to private for now') def test_referring_video_object_segmentation_with_default_task(self): input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' text_queries = [ diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index c97f347d..ddb77eeb 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -183,7 +183,7 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): task=Tasks.text_generation, model='langboat/bloom-1b4-zh') print(pipe('中国的首都是')) - @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub") + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_gpt_neo(self): pipe = pipeline( task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base') diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py index 43e1842d..c92b5568 100644 --- a/tests/pipelines/test_tinynas_detection.py +++ b/tests/pipelines/test_tinynas_detection.py @@ -20,16 +20,16 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_object_detection, model='damo/cv_tinynas_detection') result = tinynas_object_detection( 'data/test/images/image_detection.jpg') - print(result) + print('airdet', result) - @unittest.skip('will be enabled after damoyolo officially released') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_damoyolo(self): tinynas_object_detection = pipeline( Tasks.image_object_detection, 
model='damo/cv_tinynas_object-detection_damoyolo') result = tinynas_object_detection( 'data/test/images/image_detection.jpg') - print(result) + print('damoyolo', result) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): @@ -39,7 +39,8 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): def test_image_object_detection_auto_pipeline(self): test_image = 'data/test/images/image_detection.jpg' tinynas_object_detection = pipeline( - Tasks.image_object_detection, model='damo/cv_tinynas_detection') + Tasks.image_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo') result = tinynas_object_detection(test_image) tinynas_object_detection.show_result(test_image, result, 'demo_ret.jpg') diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index ae780793..02dd9d2f 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): train_datasets = [] from datasets import DownloadConfig dc = DownloadConfig() - dc.local_files_only = True + dc.local_files_only = False for lang in langs: train_datasets.append( load_dataset('xnli', lang, split='train', download_config=dc)) diff --git a/tests/trainers/test_referring_video_object_segmentation_trainer.py b/tests/trainers/test_referring_video_object_segmentation_trainer.py new file mode 100644 index 00000000..7b03eb4d --- /dev/null +++ b/tests/trainers/test_referring_video_object_segmentation_trainer.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.movie_scene_segmentation import \ + MovieSceneSegmentationModel +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageInstanceSegmentationTrainer(unittest.TestCase): + + model_id = 'damo/cv_swin-t_referring_video-object-segmentation' + dataset_name = 'referring_vos_toydata' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + max_epochs = cfg.train.max_epochs + + train_data_cfg = ConfigDict( + name=self.dataset_name, + split='train', + test_mode=False, + cfg=cfg.dataset) + + test_data_cfg = ConfigDict( + name=self.dataset_name, + split='test', + test_mode=True, + cfg=cfg.dataset) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + cfg=train_data_cfg.cfg, + namespace='damo', + test_mode=train_data_cfg.test_mode) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + cfg=test_data_cfg.cfg, + namespace='damo', + test_mode=test_data_cfg.test_mode) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + self.max_epochs = max_epochs + + 
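The setUp above verifies that MsDataset.load actually resolved each toy-data split to a local directory via `next(iter(...config_kwargs['split_config'].values()))`, which is easy to misread. Below is a minimal sketch of the same check written as a small helper; the helper name `_first_split_path` and the exact layout of `split_config` (split name mapped to a local path) are assumptions inferred from how the test uses it, not documented API, and the helper is not part of this patch.

def _first_split_path(ms_dataset):
    """Return the local path resolved for the first split, or fail loudly.

    Illustrative sketch only; assumes config_kwargs['split_config'] maps
    split names to the local directories MsDataset.load downloaded into,
    which is what the assertions in setUp rely on.
    """
    split_config = ms_dataset.config_kwargs['split_config']
    local_path = next(iter(split_config.values()))
    assert local_path, f'no local data resolved for splits: {list(split_config)}'
    return local_path

With such a helper, the two assertions in setUp would read as `_first_split_path(self.train_dataset)` and `_first_split_path(self.test_dataset)`.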
@unittest.skip('skip since the model is set to private for now') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir='./work_dir') + + trainer = build_trainer( + name=Trainers.referring_video_object_segmentation, + default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + @unittest.skip('skip since the model is set to private for now') + def test_trainer_with_model_and_args(self): + + cache_path = snapshot_download(self.model_id) + model = MovieSceneSegmentationModel.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir='./work_dir') + + trainer = build_trainer( + name=Trainers.referring_video_object_segmentation, + default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_team_transfer_trainer.py b/tests/trainers/test_team_transfer_trainer.py new file mode 100644 index 00000000..0f6b88bb --- /dev/null +++ b/tests/trainers/test_team_transfer_trainer.py @@ -0,0 +1,94 @@ +import os +import unittest + +import json +import requests +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.trainers.multi_modal.team.team_trainer_utils import ( + collate_fn, train_mapping, val_mapping) +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +def train_worker(device_id): + model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity' + ckpt_dir = './ckpt' + os.makedirs(ckpt_dir, exist_ok=True) + # Use epoch=1 for faster training here + cfg = Config({ + 'framework': 'pytorch', + 'task': 'multi-modal-similarity', + 'pipeline': { + 'type': 'multi-modal-similarity' + }, + 'model': { + 'type': 'team-multi-modal-similarity' + }, + 'dataset': { + 'name': 'Caltech101', + 'class_num': 101 + }, + 'preprocessor': {}, + 'train': { + 'epoch': 1, + 'batch_size': 32, + 'ckpt_dir': ckpt_dir + }, + 'evaluation': { + 'batch_size': 64 + } + }) + cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION) + cfg.dump(cfg_file) + + train_dataset = MsDataset.load( + cfg.dataset.name, + namespace='modelscope', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() + train_dataset = train_dataset.with_transform(train_mapping) + val_dataset = MsDataset.load( + cfg.dataset.name, + namespace='modelscope', + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() + val_dataset = val_dataset.with_transform(val_mapping) + + default_args = dict( + cfg_file=cfg_file, + model=model_id, + device_id=device_id, + data_collator=collate_fn, + train_dataset=train_dataset, + val_dataset=val_dataset) + + trainer = build_trainer( + name=Trainers.image_classification_team, default_args=default_args) + trainer.train() + trainer.evaluate() + + 
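train_worker above selects a single device and runs one trainer, and the test class below calls it with device_id 0 or -1. For reference, here is a minimal launcher sketch that is not part of this patch: the `launch_team_training` helper and its multi-GPU branch are assumptions for illustration only, and the spawned workers would train independently because train_worker never initializes a distributed process group.

import torch
import torch.multiprocessing as mp


def launch_team_training():
    # Mirrors the single-device selection used by the test class below,
    # plus an illustrative multi-GPU branch that runs train_worker(rank)
    # once per visible GPU. No gradient sharing or collective setup is
    # performed here.
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        train_worker(device_id=-1)  # CPU fallback, same convention as the test
    elif n_gpus == 1:
        train_worker(device_id=0)
    else:
        # mp.spawn passes the process index as the first positional
        # argument, which train_worker interprets as device_id.
        mp.spawn(train_worker, nprocs=n_gpus)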
+class TEAMTransferTrainerTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + if torch.cuda.device_count() > 0: + train_worker(device_id=0) + else: + train_worker(device_id=-1) + logger.info('Training done') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index 8aaa42a3..d9d56b60 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -119,7 +119,7 @@ class TestTrainerWithNlp(unittest.TestCase): checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) self.assertTrue(Metrics.accuracy in eval_results) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skip('skip for now before test is re-configured') def test_trainer_with_configured_datasets(self): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) @@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase): trainer, 'trainer_continue_train', level='strict'): trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_evaluation(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + cache_path = snapshot_download(model_id) + model = SbertForSequenceClassification.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + eval_dataset=self.dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + print(trainer.evaluate(cache_path + '/pytorch_model.bin')) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_model_and_args(self): tmp_dir = tempfile.TemporaryDirectory().name if not os.path.exists(tmp_dir): os.makedirs(tmp_dir) - model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' cache_path = snapshot_download(model_id) model = SbertForSequenceClassification.from_pretrained(cache_path) kwargs = dict(