@@ -176,7 +176,10 @@ class HubApi: | |||
""" | |||
cookies = ModelScopeConfig.get_cookies() | |||
owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | |||
if revision: | |||
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | |||
else: | |||
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}' | |||
r = requests.get(path, cookies=cookies, headers=self.headers) | |||
handle_http_response(r, logger, cookies, model_id) | |||
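A minimal sketch of the conditional URL construction above, with an illustrative helper name and endpoint; the point is simply that the Revision query parameter is omitted when no revision is given, instead of sending a literal 'Revision=None':
def _build_model_url(endpoint, owner_or_group, name, revision=None):
    # Illustrative helper (not part of HubApi): append Revision only when provided.
    base = f'{endpoint}/api/v1/models/{owner_or_group}/{name}'
    return f'{base}?Revision={revision}' if revision else base
# _build_model_url('https://modelscope.cn', 'damo', 'some-model')            -> no Revision parameter
# _build_model_url('https://modelscope.cn', 'damo', 'some-model', 'v1.0.0')  -> '...?Revision=v1.0.0'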
@@ -447,8 +450,12 @@ class HubApi: | |||
Returns: | |||
List[dict]: Model file list. | |||
""" | |||
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % ( | |||
self.endpoint, model_id, revision, recursive) | |||
if revision: | |||
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % ( | |||
self.endpoint, model_id, revision, recursive) | |||
else: | |||
path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % ( | |||
self.endpoint, model_id, recursive) | |||
cookies = self._check_cookie(use_cookies) | |||
if root is not None: | |||
path = path + f'&Root={root}' | |||
@@ -499,13 +506,14 @@ class HubApi: | |||
shutil.rmtree(cache_dir) | |||
os.makedirs(cache_dir, exist_ok=True) | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' | |||
r = requests.get(datahub_url) | |||
cookies = ModelScopeConfig.get_cookies() | |||
r = requests.get(datahub_url, cookies=cookies) | |||
resp = r.json() | |||
datahub_raise_on_error(datahub_url, resp) | |||
dataset_id = resp['Data']['Id'] | |||
dataset_type = resp['Data']['Type'] | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
r = requests.get(datahub_url, headers=self.headers) | |||
r = requests.get(datahub_url, cookies=cookies, headers=self.headers) | |||
resp = r.json() | |||
datahub_raise_on_error(datahub_url, resp) | |||
file_list = resp['Data'] | |||
@@ -524,7 +532,7 @@ class HubApi: | |||
if extension in dataset_meta_format: | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
f'Revision={revision}&FilePath={file_path}' | |||
r = requests.get(datahub_url) | |||
r = requests.get(datahub_url, cookies=cookies) | |||
raise_for_http_status(r) | |||
local_path = os.path.join(cache_dir, file_path) | |||
if os.path.exists(local_path): | |||
@@ -569,9 +577,7 @@ class HubApi: | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
f'ststoken?Revision={revision}' | |||
cookies = requests.utils.dict_from_cookiejar(cookies) | |||
r = requests.get( | |||
url=datahub_url, cookies=cookies, headers=self.headers) | |||
r = requests.get(url=datahub_url, cookies=cookies, headers=self.headers) | |||
resp = r.json() | |||
raise_on_error(resp) | |||
return resp['Data'] | |||
@@ -582,9 +588,6 @@ class HubApi: | |||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | |||
cookies = ModelScopeConfig.get_cookies() | |||
if cookies: | |||
cookies = requests.utils.dict_from_cookiejar(cookies) | |||
resp = requests.get(url=url, cookies=cookies) | |||
resp = resp.json() | |||
raise_on_error(resp) | |||
@@ -593,17 +596,48 @@ class HubApi: | |||
def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | |||
r = requests.post(url, headers=self.headers) | |||
cookies = ModelScopeConfig.get_cookies() | |||
r = requests.post(url, cookies=cookies, headers=self.headers) | |||
raise_for_http_status(r) | |||
def delete_oss_dataset_object(self, object_name: str, dataset_name: str, | |||
namespace: str, revision: str) -> str: | |||
if not object_name or not dataset_name or not namespace or not revision: | |||
raise ValueError('Args cannot be empty!') | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}' | |||
cookies = self.check_local_cookies(use_cookies=True) | |||
resp = requests.delete(url=url, cookies=cookies) | |||
resp = resp.json() | |||
raise_on_error(resp) | |||
resp = resp['Message'] | |||
return resp | |||
def delete_oss_dataset_dir(self, object_name: str, dataset_name: str, | |||
namespace: str, revision: str) -> str: | |||
if not object_name or not dataset_name or not namespace or not revision: | |||
raise ValueError('Args cannot be empty!') | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \ | |||
f'&Revision={revision}' | |||
cookies = self.check_local_cookies(use_cookies=True) | |||
resp = requests.delete(url=url, cookies=cookies) | |||
resp = resp.json() | |||
raise_on_error(resp) | |||
resp = resp['Message'] | |||
return resp | |||
@staticmethod | |||
def datahub_remote_call(url): | |||
r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||
cookies = ModelScopeConfig.get_cookies() | |||
r = requests.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||
resp = r.json() | |||
datahub_raise_on_error(url, resp) | |||
return resp['Data'] | |||
def check_cookies_upload_data(self, use_cookies) -> CookieJar: | |||
def check_local_cookies(self, use_cookies) -> CookieJar: | |||
return self._check_cookie(use_cookies=use_cookies) | |||
@@ -63,6 +63,7 @@ def handle_http_post_error(response, url, request_body): | |||
except HTTPError as error: | |||
logger.error('Request %s with body: %s exception' % | |||
(url, request_body)) | |||
logger.error('Response details: %s' % response.content) | |||
raise error | |||
@@ -254,6 +254,7 @@ class Pipelines(object): | |||
translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||
token_classification = 'token-classification' | |||
# audio tasks | |||
sambert_hifigan_tts = 'sambert-hifigan-tts' | |||
@@ -305,6 +306,8 @@ class Trainers(object): | |||
face_detection_scrfd = 'face-detection-scrfd' | |||
card_detection_scrfd = 'card-detection-scrfd' | |||
image_inpainting = 'image-inpainting' | |||
referring_video_object_segmentation = 'referring-video-object-segmentation' | |||
image_classification_team = 'image-classification-team' | |||
# nlp trainers | |||
bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
@@ -422,6 +425,8 @@ class Metrics(object): | |||
image_inpainting_metric = 'image-inpainting-metric' | |||
# metric for ocr | |||
NED = 'ned' | |||
# metric for referring-video-object-segmentation task | |||
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' | |||
class Optimizers(object): | |||
@@ -20,6 +20,7 @@ if TYPE_CHECKING: | |||
from .accuracy_metric import AccuracyMetric | |||
from .bleu_metric import BleuMetric | |||
from .image_inpainting_metric import ImageInpaintingMetric | |||
from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric | |||
else: | |||
_import_structure = { | |||
@@ -40,6 +41,8 @@ else: | |||
'image_inpainting_metric': ['ImageInpaintingMetric'], | |||
'accuracy_metric': ['AccuracyMetric'], | |||
'bleu_metric': ['BleuMetric'], | |||
'referring_video_object_segmentation_metric': | |||
['ReferringVideoObjectSegmentationMetric'], | |||
} | |||
import sys | |||
@@ -43,6 +43,8 @@ task_default_metrics = { | |||
Tasks.visual_question_answering: [Metrics.text_gen_metric], | |||
Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | |||
Tasks.image_inpainting: [Metrics.image_inpainting_metric], | |||
Tasks.referring_video_object_segmentation: | |||
[Metrics.referring_video_object_segmentation_metric], | |||
} | |||
@@ -0,0 +1,108 @@ | |||
# Part of the implementation is borrowed and modified from MTTR, | |||
# publicly available at https://github.com/mttr2021/MTTR | |||
from typing import Dict | |||
import numpy as np | |||
import torch | |||
from pycocotools.coco import COCO | |||
from pycocotools.cocoeval import COCOeval | |||
from pycocotools.mask import decode | |||
from tqdm import tqdm | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@METRICS.register_module( | |||
group_key=default_group, | |||
module_name=Metrics.referring_video_object_segmentation_metric) | |||
class ReferringVideoObjectSegmentationMetric(Metric): | |||
"""The metric computation class for movie scene segmentation classes. | |||
""" | |||
def __init__(self, | |||
ann_file=None, | |||
calculate_precision_and_iou_metrics=True): | |||
self.ann_file = ann_file | |||
self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics | |||
self.preds = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
preds_batch = outputs['pred'] | |||
self.preds.extend(preds_batch) | |||
def evaluate(self): | |||
coco_gt = COCO(self.ann_file) | |||
coco_pred = coco_gt.loadRes(self.preds) | |||
coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') | |||
coco_eval.params.useCats = 0 | |||
coco_eval.evaluate() | |||
coco_eval.accumulate() | |||
coco_eval.summarize() | |||
ap_labels = [ | |||
'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', | |||
'AP 0.5:0.95 M', 'AP 0.5:0.95 L' | |||
] | |||
ap_metrics = coco_eval.stats[:6] | |||
eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)} | |||
if self.calculate_precision_and_iou_metrics: | |||
precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics( | |||
coco_gt, coco_pred) | |||
eval_metrics.update({ | |||
f'P@{k}': m | |||
for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k) | |||
}) | |||
eval_metrics.update({ | |||
'overall_iou': overall_iou, | |||
'mean_iou': mean_iou | |||
}) | |||
return eval_metrics | |||
def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): | |||
outputs = outputs.int() | |||
intersection = (outputs & labels).float().sum( | |||
(1, 2)) # Will be zero if Truth=0 or Prediction=0 | |||
union = (outputs | labels).float().sum( | |||
(1, 2)) # Will be zero if both are 0 | |||
iou = (intersection + EPS) / (union + EPS | |||
) # EPS is used to avoid division by zero | |||
return iou, intersection, union | |||
def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): | |||
print('evaluating precision@k & iou metrics...') | |||
counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} | |||
total_intersection_area = 0 | |||
total_union_area = 0 | |||
ious_list = [] | |||
for instance in tqdm(coco_gt.imgs.keys() | |||
): # each image_id contains exactly one instance | |||
gt_annot = coco_gt.imgToAnns[instance][0] | |||
gt_mask = decode(gt_annot['segmentation']) | |||
pred_annots = coco_pred.imgToAnns[instance] | |||
pred_annot = sorted( | |||
pred_annots, | |||
key=lambda a: a['score'])[-1] # choose pred with highest score | |||
pred_mask = decode(pred_annot['segmentation']) | |||
iou, intersection, union = compute_iou( | |||
torch.tensor(pred_mask).unsqueeze(0), | |||
torch.tensor(gt_mask).unsqueeze(0)) | |||
iou, intersection, union = iou.item(), intersection.item(), union.item( | |||
) | |||
for iou_threshold in counters_by_iou.keys(): | |||
if iou > iou_threshold: | |||
counters_by_iou[iou_threshold] += 1 | |||
total_intersection_area += intersection | |||
total_union_area += union | |||
ious_list.append(iou) | |||
num_samples = len(ious_list) | |||
precision_at_k = np.array(list(counters_by_iou.values())) / num_samples | |||
overall_iou = total_intersection_area / total_union_area | |||
mean_iou = np.mean(ious_list) | |||
return precision_at_k, overall_iou, mean_iou |
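A small self-contained check of the compute_iou helper above on toy binary masks (values chosen by hand):
import torch

# prediction covers a 2x2 block, ground truth a 2x3 block -> intersection 4, union 6
pred = torch.zeros(1, 4, 4, dtype=torch.int64)
pred[0, :2, :2] = 1
gt = torch.zeros(1, 4, 4, dtype=torch.int64)
gt[0, :2, :3] = 1
iou, intersection, union = compute_iou(pred, gt)
# intersection = 4.0, union = 6.0, iou ≈ 0.667 (EPS only guards against empty masks)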
@@ -3,6 +3,7 @@ | |||
from typing import Dict | |||
import numpy as np | |||
from sklearn.metrics import accuracy_score, f1_score | |||
from modelscope.metainfo import Metrics | |||
from modelscope.outputs import OutputKeys | |||
@@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric): | |||
preds = np.argmax(preds, axis=1) | |||
return { | |||
MetricKeys.ACCURACY: | |||
(preds == labels).astype(np.float32).mean().item() | |||
accuracy_score(labels, preds), | |||
MetricKeys.F1: | |||
f1_score( | |||
labels, | |||
preds, | |||
average='micro' if any([label > 1 | |||
for label in labels]) else None), | |||
} |
@@ -2,7 +2,7 @@ | |||
from typing import Dict, Iterable, List | |||
from nltk.translate.bleu_score import sentence_bleu | |||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu | |||
from rouge import Rouge | |||
from modelscope.metainfo import Metrics | |||
@@ -63,14 +63,18 @@ class TextGenerationMetric(Metric): | |||
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) | |||
rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) | |||
rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) | |||
pred_split = tuple(pred.split(' ') for pred in self.preds) | |||
tgt_split = tuple(tgt.split(' ') for tgt in self.tgts) | |||
bleu_1 = mean( | |||
sentence_bleu([tgt], pred, weights=(1, 0, 0, 0)) | |||
for pred, tgt in zip(pred_split, tgt_split)) | |||
bleu_4 = mean( | |||
sentence_bleu([tgt], pred) | |||
for pred, tgt in zip(pred_split, tgt_split)) | |||
pred_list = [each.strip().split(' ') for each in self.preds] | |||
tgt_list = [[each.strip().split(' ')] for each in self.tgts] | |||
bleu_1 = corpus_bleu( | |||
tgt_list, | |||
pred_list, | |||
weights=(1, 0, 0, 0), | |||
smoothing_function=SmoothingFunction().method3) | |||
bleu_4 = corpus_bleu( | |||
tgt_list, | |||
pred_list, | |||
smoothing_function=SmoothingFunction().method3) | |||
return { | |||
MetricKeys.ROUGE_1: rouge_1, | |||
MetricKeys.ROUGE_L: rouge_l, | |||
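A minimal illustration of the corpus_bleu call pattern adopted above; note that every reference token list is wrapped in an extra list, because corpus_bleu accepts multiple references per hypothesis (toy sentences only):
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

hyps = [['the', 'cat', 'sat'], ['a', 'dog', 'barked']]
refs = [[['the', 'cat', 'sat', 'down']], [['the', 'dog', 'barked']]]  # one reference list per hypothesis
bleu_1 = corpus_bleu(refs, hyps, weights=(1, 0, 0, 0),
                     smoothing_function=SmoothingFunction().method3)
bleu_4 = corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method3)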
@@ -67,8 +67,28 @@ class Model(ABC): | |||
cfg_dict: Config = None, | |||
device: str = None, | |||
**kwargs): | |||
""" Instantiate a model from local directory or remote model repo. Note | |||
"""Instantiate a model from local directory or remote model repo. Note | |||
that when loading from remote, the model revision can be specified. | |||
Args: | |||
model_name_or_path(str): A model dir or a model id to be loaded | |||
revision(str, `optional`): The revision used when the model_name_or_path is | |||
a model id on the remote hub. Defaults to `master`. | |||
cfg_dict(Config, `optional`): An optional model config. If provided, it will replace | |||
the config read out of the `model_name_or_path` | |||
device(str, `optional`): The device to load the model. | |||
**kwargs: | |||
task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||
read out of config in the `model_name_or_path`. This is useful when the model to be loaded differs | |||
from the model that was saved. | |||
For example, load a `backbone` into a `text-classification` model. | |||
Other kwargs will be directly fed into the `model` key, to replace the default configs. | |||
Returns: | |||
A model instance. | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification') | |||
""" | |||
prefetched = kwargs.get('model_prefetched') | |||
if prefetched is not None: | |||
@@ -5,11 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .model import MovieSceneSegmentation | |||
from .model import ReferringVideoObjectSegmentation | |||
else: | |||
_import_structure = { | |||
'model': ['MovieSceneSegmentation'], | |||
'model': ['ReferringVideoObjectSegmentation'], | |||
} | |||
import sys | |||
@@ -1,4 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Part of the implementation is borrowed and modified from MTTR, | |||
# publicly available at https://github.com/mttr2021/MTTR | |||
import os.path as osp | |||
from typing import Any, Dict | |||
@@ -10,7 +12,9 @@ from modelscope.models.builder import MODELS | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .utils import (MTTR, A2DSentencesPostProcess, ReferYoutubeVOSPostProcess, | |||
from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher, | |||
ReferYoutubeVOSPostProcess, SetCriterion, | |||
flatten_temporal_batch_dims, | |||
nested_tensor_from_videos_list) | |||
logger = get_logger() | |||
@@ -35,16 +39,66 @@ class ReferringVideoObjectSegmentation(TorchModel): | |||
params_dict = params_dict['model_state_dict'] | |||
self.model.load_state_dict(params_dict, strict=True) | |||
dataset_name = self.cfg.pipeline.dataset_name | |||
if dataset_name == 'a2d_sentences' or dataset_name == 'jhmdb_sentences': | |||
self.postprocessor = A2DSentencesPostProcess() | |||
elif dataset_name == 'ref_youtube_vos': | |||
self.postprocessor = ReferYoutubeVOSPostProcess() | |||
self.set_postprocessor(self.cfg.pipeline.dataset_name) | |||
self.set_criterion() | |||
def set_device(self, device, name): | |||
self.device = device | |||
self._device_name = name | |||
def set_postprocessor(self, dataset_name): | |||
if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name: | |||
self.postprocessor = A2DSentencesPostProcess() # fine-tune | |||
elif 'ref_youtube_vos' in dataset_name: | |||
self.postprocessor = ReferYoutubeVOSPostProcess() # inference | |||
else: | |||
assert False, f'postprocessing for dataset: {dataset_name} is not supported' | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: | |||
return inputs | |||
def forward(self, inputs: Dict[str, Any]): | |||
samples = inputs['samples'] | |||
targets = inputs['targets'] | |||
text_queries = inputs['text_queries'] | |||
valid_indices = torch.tensor( | |||
[i for i, t in enumerate(targets) if None not in t]) | |||
targets = [targets[i] for i in valid_indices.tolist()] | |||
if self._device_name == 'gpu': | |||
samples = samples.to(self.device) | |||
valid_indices = valid_indices.to(self.device) | |||
if isinstance(text_queries, tuple): | |||
text_queries = list(text_queries) | |||
outputs = self.model(samples, valid_indices, text_queries) | |||
losses = -1 | |||
if self.training: | |||
loss_dict = self.criterion(outputs, targets) | |||
weight_dict = self.criterion.weight_dict | |||
losses = sum(loss_dict[k] * weight_dict[k] | |||
for k in loss_dict.keys() if k in weight_dict) | |||
predictions = [] | |||
if not self.training: | |||
outputs.pop('aux_outputs', None) | |||
outputs, targets = flatten_temporal_batch_dims(outputs, targets) | |||
processed_outputs = self.postprocessor( | |||
outputs, | |||
resized_padded_sample_size=samples.tensors.shape[-2:], | |||
resized_sample_sizes=[t['size'] for t in targets], | |||
orig_sample_sizes=[t['orig_size'] for t in targets]) | |||
image_ids = [t['image_id'] for t in targets] | |||
predictions = [] | |||
for p, image_id in zip(processed_outputs, image_ids): | |||
for s, m in zip(p['scores'], p['rle_masks']): | |||
predictions.append({ | |||
'image_id': image_id, | |||
'category_id': | |||
1, # dummy label, as categories are not predicted in ref-vos | |||
'segmentation': m, | |||
'score': s.item() | |||
}) | |||
re = dict(pred=predictions, loss=losses) | |||
return re | |||
def inference(self, **kwargs): | |||
window = kwargs['window'] | |||
@@ -63,3 +117,26 @@ class ReferringVideoObjectSegmentation(TorchModel): | |||
def postprocess(self, inputs: Dict[str, Any], **kwargs): | |||
return inputs | |||
def set_criterion(self): | |||
matcher = HungarianMatcher( | |||
cost_is_referred=self.cfg.matcher.set_cost_is_referred, | |||
cost_dice=self.cfg.matcher.set_cost_dice) | |||
weight_dict = { | |||
'loss_is_referred': self.cfg.loss.is_referred_loss_coef, | |||
'loss_dice': self.cfg.loss.dice_loss_coef, | |||
'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef | |||
} | |||
if self.cfg.loss.aux_loss: | |||
aux_weight_dict = {} | |||
for i in range(self.cfg.model.num_decoder_layers - 1): | |||
aux_weight_dict.update( | |||
{k + f'_{i}': v | |||
for k, v in weight_dict.items()}) | |||
weight_dict.update(aux_weight_dict) | |||
self.criterion = SetCriterion( | |||
matcher=matcher, | |||
weight_dict=weight_dict, | |||
eos_coef=self.cfg.loss.eos_coef) |
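A small worked sketch of the auxiliary-loss expansion in set_criterion above, with illustrative coefficients and 3 decoder layers:
weight_dict = {'loss_is_referred': 2.0, 'loss_dice': 5.0, 'loss_sigmoid_focal': 2.0}  # illustrative values
num_decoder_layers = 3
aux_weight_dict = {}
for i in range(num_decoder_layers - 1):
    aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
# weight_dict now also holds 'loss_is_referred_0', 'loss_dice_0', 'loss_sigmoid_focal_0',
# 'loss_is_referred_1', ... matching the '_{i}' suffixes SetCriterion.forward appends to
# the losses of each intermediate decoder layer.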
@@ -1,4 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .misc import nested_tensor_from_videos_list | |||
from .criterion import SetCriterion, flatten_temporal_batch_dims | |||
from .matcher import HungarianMatcher | |||
from .misc import interpolate, nested_tensor_from_videos_list | |||
from .mttr import MTTR | |||
from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess |
@@ -0,0 +1,198 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
import torch | |||
from torch import nn | |||
from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized, | |||
nested_tensor_from_tensor_list) | |||
from .segmentation import dice_loss, sigmoid_focal_loss | |||
class SetCriterion(nn.Module): | |||
""" This class computes the loss for MTTR. | |||
The process happens in two steps: | |||
1) we compute the hungarian assignment between the ground-truth and predicted sequences. | |||
2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction) | |||
""" | |||
def __init__(self, matcher, weight_dict, eos_coef): | |||
""" Create the criterion. | |||
Parameters: | |||
matcher: module able to compute a matching between targets and proposals | |||
weight_dict: dict containing as key the names of the losses and as values their relative weight. | |||
eos_coef: relative classification weight applied to the un-referred category | |||
""" | |||
super().__init__() | |||
self.matcher = matcher | |||
self.weight_dict = weight_dict | |||
self.eos_coef = eos_coef | |||
# make sure that only loss functions with non-zero weights are computed: | |||
losses_to_compute = [] | |||
if weight_dict['loss_dice'] > 0 or weight_dict[ | |||
'loss_sigmoid_focal'] > 0: | |||
losses_to_compute.append('masks') | |||
if weight_dict['loss_is_referred'] > 0: | |||
losses_to_compute.append('is_referred') | |||
self.losses = losses_to_compute | |||
def forward(self, outputs, targets): | |||
aux_outputs_list = outputs.pop('aux_outputs', None) | |||
# compute the losses for the output of the last decoder layer: | |||
losses = self.compute_criterion( | |||
outputs, targets, losses_to_compute=self.losses) | |||
# In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer. | |||
if aux_outputs_list is not None: | |||
aux_losses_to_compute = self.losses.copy() | |||
for i, aux_outputs in enumerate(aux_outputs_list): | |||
losses_dict = self.compute_criterion(aux_outputs, targets, | |||
aux_losses_to_compute) | |||
losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()} | |||
losses.update(losses_dict) | |||
return losses | |||
def compute_criterion(self, outputs, targets, losses_to_compute): | |||
# Retrieve the matching between the outputs of the last layer and the targets | |||
indices = self.matcher(outputs, targets) | |||
# T & B dims are flattened so loss functions can be computed per frame (but with same indices per video). | |||
# also, indices are repeated so the same indices can be used for frames of the same video. | |||
T = len(targets) | |||
outputs, targets = flatten_temporal_batch_dims(outputs, targets) | |||
# repeat the indices list T times so the same indices can be used for each video frame | |||
indices = T * indices | |||
# Compute the average number of target masks across all nodes, for normalization purposes | |||
num_masks = sum(len(t['masks']) for t in targets) | |||
num_masks = torch.as_tensor([num_masks], | |||
dtype=torch.float, | |||
device=indices[0][0].device) | |||
if is_dist_avail_and_initialized(): | |||
torch.distributed.all_reduce(num_masks) | |||
num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() | |||
# Compute all the requested losses | |||
losses = {} | |||
for loss in losses_to_compute: | |||
losses.update( | |||
self.get_loss( | |||
loss, outputs, targets, indices, num_masks=num_masks)) | |||
return losses | |||
def loss_is_referred(self, outputs, targets, indices, **kwargs): | |||
device = outputs['pred_is_referred'].device | |||
bs = outputs['pred_is_referred'].shape[0] | |||
pred_is_referred = outputs['pred_is_referred'].log_softmax( | |||
dim=-1) # note that log-softmax is used here | |||
target_is_referred = torch.zeros_like(pred_is_referred) | |||
# extract indices of object queries that were matched with text-referred target objects | |||
query_referred_indices = self._get_query_referred_indices( | |||
indices, targets) | |||
# by default penalize compared to the no-object class (last token) | |||
target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) | |||
if 'is_ref_inst_visible' in targets[ | |||
0]: # visibility labels are available per-frame for the referred object: | |||
is_ref_inst_visible_per_frame = torch.stack( | |||
[t['is_ref_inst_visible'] for t in targets]) | |||
ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero( | |||
).squeeze() | |||
# keep only the matched query indices of the frames in which the referred object is visible: | |||
visible_query_referred_indices = query_referred_indices[ | |||
ref_inst_visible_frame_indices] | |||
target_is_referred[ref_inst_visible_frame_indices, | |||
visible_query_referred_indices] = torch.tensor( | |||
[1.0, 0.0], device=device) | |||
else: # assume that the referred object is visible in every frame: | |||
target_is_referred[torch.arange(bs), | |||
query_referred_indices] = torch.tensor( | |||
[1.0, 0.0], device=device) | |||
loss = -(pred_is_referred * target_is_referred).sum(-1) | |||
# apply no-object class weights: | |||
eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device) | |||
eos_coef[torch.arange(bs), query_referred_indices] = 1.0 | |||
loss = loss * eos_coef | |||
bs = len(indices) | |||
loss = loss.sum() / bs # sum and normalize the loss by the batch size | |||
losses = {'loss_is_referred': loss} | |||
return losses | |||
def loss_masks(self, outputs, targets, indices, num_masks, **kwargs): | |||
assert 'pred_masks' in outputs | |||
src_idx = self._get_src_permutation_idx(indices) | |||
tgt_idx = self._get_tgt_permutation_idx(indices) | |||
src_masks = outputs['pred_masks'] | |||
src_masks = src_masks[src_idx] | |||
masks = [t['masks'] for t in targets] | |||
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() | |||
target_masks = target_masks.to(src_masks) | |||
target_masks = target_masks[tgt_idx] | |||
# upsample predictions to the target size | |||
src_masks = interpolate( | |||
src_masks[:, None], | |||
size=target_masks.shape[-2:], | |||
mode='bilinear', | |||
align_corners=False) | |||
src_masks = src_masks[:, 0].flatten(1) | |||
target_masks = target_masks.flatten(1) | |||
target_masks = target_masks.view(src_masks.shape) | |||
losses = { | |||
'loss_sigmoid_focal': | |||
sigmoid_focal_loss(src_masks, target_masks, num_masks), | |||
'loss_dice': | |||
dice_loss(src_masks, target_masks, num_masks), | |||
} | |||
return losses | |||
@staticmethod | |||
def _get_src_permutation_idx(indices): | |||
# permute predictions following indices | |||
batch_idx = torch.cat( | |||
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) | |||
src_idx = torch.cat([src for (src, _) in indices]) | |||
return batch_idx, src_idx | |||
@staticmethod | |||
def _get_tgt_permutation_idx(indices): | |||
# permute targets following indices | |||
batch_idx = torch.cat( | |||
[torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) | |||
tgt_idx = torch.cat([tgt for (_, tgt) in indices]) | |||
return batch_idx, tgt_idx | |||
@staticmethod | |||
def _get_query_referred_indices(indices, targets): | |||
""" | |||
extract indices of object queries that were matched with text-referred target objects | |||
""" | |||
query_referred_indices = [] | |||
for (query_idxs, target_idxs), target in zip(indices, targets): | |||
ref_query_idx = query_idxs[torch.where( | |||
target_idxs == target['referred_instance_idx'])[0]] | |||
query_referred_indices.append(ref_query_idx) | |||
query_referred_indices = torch.cat(query_referred_indices) | |||
return query_referred_indices | |||
def get_loss(self, loss, outputs, targets, indices, **kwargs): | |||
loss_map = { | |||
'masks': self.loss_masks, | |||
'is_referred': self.loss_is_referred, | |||
} | |||
assert loss in loss_map, f'do you really want to compute {loss} loss?' | |||
return loss_map[loss](outputs, targets, indices, **kwargs) | |||
def flatten_temporal_batch_dims(outputs, targets): | |||
for k in outputs.keys(): | |||
if isinstance(outputs[k], torch.Tensor): | |||
outputs[k] = outputs[k].flatten(0, 1) | |||
else: # list | |||
outputs[k] = [i for step_t in outputs[k] for i in step_t] | |||
targets = [ | |||
frame_t_target for step_t in targets for frame_t_target in step_t | |||
] | |||
return outputs, targets |
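A toy illustration of flatten_temporal_batch_dims above with T=2 frames and B=2 samples (shapes and keys are made up for the example):
import torch

outputs = {'pred_masks': torch.zeros(2, 2, 5, 16, 16)}  # [T, B, num_queries, H, W]
targets = [[{'id': 'b0'}, {'id': 'b1'}],                 # frame 0
           [{'id': 'b0'}, {'id': 'b1'}]]                 # frame 1
outputs, targets = flatten_temporal_batch_dims(outputs, targets)
# outputs['pred_masks'].shape == (4, 5, 16, 16): the time and batch dims are merged,
# and len(targets) == 4, i.e. one target dict per (frame, sample) pair.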
@@ -0,0 +1,163 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
# Module to compute the matching cost and solve the corresponding LSAP. | |||
import torch | |||
from scipy.optimize import linear_sum_assignment | |||
from torch import nn | |||
from .misc import interpolate, nested_tensor_from_tensor_list | |||
class HungarianMatcher(nn.Module): | |||
"""This class computes an assignment between the targets and the predictions of the network | |||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, | |||
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, | |||
while the others are un-matched (and thus treated as non-objects). | |||
""" | |||
def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1): | |||
"""Creates the matcher | |||
Params: | |||
cost_is_referred: This is the relative weight of the reference cost in the total matching cost | |||
cost_dice: This is the relative weight of the dice cost in the total matching cost | |||
""" | |||
super().__init__() | |||
self.cost_is_referred = cost_is_referred | |||
self.cost_dice = cost_dice | |||
assert cost_is_referred != 0 or cost_dice != 0, 'all costs cannot be 0' | |||
@torch.inference_mode() | |||
def forward(self, outputs, targets): | |||
""" Performs the matching | |||
Params: | |||
outputs: A dict that contains at least these entries: | |||
"pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits | |||
"pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits | |||
targets: A list of lists of targets (outer - time steps, inner - batch samples). Each target is a dict | |||
which contains mask and reference ground-truth information for a single frame. | |||
Returns: | |||
A list of size batch_size, containing tuples of (index_i, index_j) where: | |||
- index_i is the indices of the selected predictions (in order) | |||
- index_j is the indices of the corresponding selected targets (in order) | |||
For each batch element, it holds: | |||
len(index_i) = len(index_j) = min(num_queries, num_target_masks) | |||
""" | |||
t, bs, num_queries = outputs['pred_masks'].shape[:3] | |||
# We flatten to compute the cost matrices in a batch | |||
out_masks = outputs['pred_masks'].flatten( | |||
1, 2) # [t, batch_size * num_queries, mask_h, mask_w] | |||
# preprocess and concat the target masks | |||
tgt_masks = [[ | |||
m for v in t_step_batch for m in v['masks'].unsqueeze(1) | |||
] for t_step_batch in targets] | |||
# pad the target masks to a uniform shape | |||
tgt_masks, valid = list( | |||
zip(*[ | |||
nested_tensor_from_tensor_list(t).decompose() | |||
for t in tgt_masks | |||
])) | |||
tgt_masks = torch.stack(tgt_masks).squeeze(2) | |||
# upsample predicted masks to target mask size | |||
out_masks = interpolate( | |||
out_masks, | |||
size=tgt_masks.shape[-2:], | |||
mode='bilinear', | |||
align_corners=False) | |||
# Compute the soft-tokens cost: | |||
if self.cost_is_referred > 0: | |||
cost_is_referred = compute_is_referred_cost(outputs, targets) | |||
else: | |||
cost_is_referred = 0 | |||
# Compute the DICE coefficient between the masks: | |||
if self.cost_dice > 0: | |||
cost_dice = -dice_coef(out_masks, tgt_masks) | |||
else: | |||
cost_dice = 0 | |||
# Final cost matrix | |||
C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice | |||
C = C.view(bs, num_queries, -1).cpu() | |||
num_traj_per_batch = [ | |||
len(v['masks']) for v in targets[0] | |||
] # number of instance trajectories in each batch | |||
indices = [ | |||
linear_sum_assignment(c[i]) | |||
for i, c in enumerate(C.split(num_traj_per_batch, -1)) | |||
] | |||
device = out_masks.device | |||
return [(torch.as_tensor(i, dtype=torch.int64, device=device), | |||
torch.as_tensor(j, dtype=torch.int64, device=device)) | |||
for i, j in indices] | |||
def dice_coef(inputs, targets, smooth=1.0): | |||
""" | |||
Compute the DICE coefficient, similar to generalized IOU for masks | |||
Args: | |||
inputs: A float tensor of arbitrary shape. | |||
The predictions for each example. | |||
targets: A float tensor with the same shape as inputs. Stores the binary | |||
classification label for each element in inputs | |||
(0 for the negative class and 1 for the positive class). | |||
""" | |||
inputs = inputs.sigmoid().flatten(2).unsqueeze(2) | |||
targets = targets.flatten(2).unsqueeze(1) | |||
numerator = 2 * (inputs * targets).sum(-1) | |||
denominator = inputs.sum(-1) + targets.sum(-1) | |||
coef = (numerator + smooth) / (denominator + smooth) | |||
coef = coef.mean( | |||
0) # average on the temporal dim to get instance trajectory scores | |||
return coef | |||
def compute_is_referred_cost(outputs, targets): | |||
pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax( | |||
dim=-1) # [t, b*nq, 2] | |||
device = pred_is_referred.device | |||
t = pred_is_referred.shape[0] | |||
# number of instance trajectories in each batch | |||
num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]], | |||
device=device) | |||
total_trajectories = num_traj_per_batch.sum() | |||
# note that ref_indices are shared across time steps: | |||
ref_indices = torch.tensor( | |||
[v['referred_instance_idx'] for v in targets[0]], device=device) | |||
# convert ref_indices to fit flattened batch targets: | |||
ref_indices += torch.cat( | |||
(torch.zeros(1, dtype=torch.long, | |||
device=device), num_traj_per_batch.cumsum(0)[:-1])) | |||
# number of instance trajectories in each batch | |||
target_is_referred = torch.zeros((t, total_trajectories, 2), device=device) | |||
# 'no object' class by default (for un-referred objects) | |||
target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) | |||
if 'is_ref_inst_visible' in targets[0][ | |||
0]: # visibility labels are available per-frame for the referred object: | |||
is_ref_inst_visible = torch.stack([ | |||
torch.stack([t['is_ref_inst_visible'] for t in t_step]) | |||
for t_step in targets | |||
]).permute(1, 0) | |||
for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible): | |||
is_visible = is_visible.nonzero().squeeze() | |||
target_is_referred[is_visible, | |||
ref_idx, :] = torch.tensor([1.0, 0.0], | |||
device=device) | |||
else: # assume that the referred object is visible in every frame: | |||
target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0], | |||
device=device) | |||
cost_is_referred = -(pred_is_referred.unsqueeze(2) | |||
* target_is_referred.unsqueeze(1)).sum(dim=-1).mean( | |||
dim=0) | |||
return cost_is_referred |
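For intuition on the (index_i, index_j) pairs returned by the matcher above, here is linear_sum_assignment on a hand-made 3-query x 2-target cost matrix:
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1],   # query 0
                 [0.2, 0.8],   # query 1
                 [0.5, 0.5]])  # query 2
row_ind, col_ind = linear_sum_assignment(cost)
# row_ind = [0, 1], col_ind = [1, 0]: query 0 is matched to target 1 and query 1 to target 0,
# while query 2 stays unmatched, so len(index_i) == len(index_j) == min(num_queries, num_targets).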
@@ -122,8 +122,8 @@ class MultimodalTransformer(nn.Module): | |||
with torch.inference_mode(mode=self.freeze_text_encoder): | |||
encoded_text = self.text_encoder(**tokenized_queries) | |||
# Transpose memory because pytorch's attention expects sequence first | |||
txt_memory = rearrange(encoded_text.last_hidden_state, | |||
'b s c -> s b c') | |||
tmp_last_hidden_state = encoded_text.last_hidden_state.clone() | |||
txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c') | |||
txt_memory = self.txt_proj( | |||
txt_memory) # change text embeddings dim to model dim | |||
# Invert attention mask that we get from huggingface because it's the opposite in pytorch transformer | |||
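A quick check of the rearrange pattern used above; 'b s c -> s b c' is just a batch/sequence transpose:
import torch
from einops import rearrange

x = torch.randn(2, 5, 8)  # (batch, seq, channels)
y = rearrange(x, 'b s c -> s b c')
assert y.shape == (5, 2, 8)
assert torch.equal(y, x.permute(1, 0, 2))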
@@ -123,7 +123,8 @@ class WindowAttention3D(nn.Module): | |||
# define a parameter table of relative position bias | |||
wd, wh, ww = window_size | |||
self.relative_position_bias_table = nn.Parameter( | |||
torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), num_heads)) | |||
torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), | |||
num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH | |||
# get pair-wise relative position index for each token inside the window | |||
coords_d = torch.arange(self.window_size[0]) | |||
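A worked example of the bias-table size noted in the comment above, assuming an illustrative (2, 7, 7) window:
wd, wh, ww = 2, 7, 7  # illustrative window size (D, H, W)
table_rows = (2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1)
# 3 * 13 * 13 = 507 rows, one learnable bias vector of length num_heads for every possible
# relative (d, h, w) offset between two tokens inside the window.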
@@ -269,8 +269,11 @@ class TinyNAS(nn.Module): | |||
the_block_class = block_info['class'] | |||
if the_block_class == 'ConvKXBNRELU': | |||
if use_focus: | |||
the_block = Focus(block_info['in'], block_info['out'], | |||
block_info['k']) | |||
the_block = Focus( | |||
block_info['in'], | |||
block_info['out'], | |||
block_info['k'], | |||
act=act) | |||
else: | |||
the_block = ConvKXBNRELU( | |||
block_info['in'], | |||
@@ -6,6 +6,7 @@ import pickle | |||
import cv2 | |||
import torch | |||
import torch.nn as nn | |||
import torchvision | |||
from modelscope.metainfo import Models | |||
@@ -47,6 +48,7 @@ class SingleStageDetector(TorchModel): | |||
self.backbone = build_backbone(self.cfg.model.backbone) | |||
self.neck = build_neck(self.cfg.model.neck) | |||
self.head = build_head(self.cfg.model.head) | |||
self.apply(self.init_bn) | |||
self.load_pretrain_model(model_path) | |||
@@ -59,6 +61,12 @@ class SingleStageDetector(TorchModel): | |||
new_state_dict[k] = v | |||
self.load_state_dict(new_state_dict, strict=True) | |||
def init_bn(self, M): | |||
for m in M.modules(): | |||
if isinstance(m, nn.BatchNorm2d): | |||
m.eps = 1e-3 | |||
m.momentum = 0.03 | |||
def inference(self, x): | |||
if self.training: | |||
@@ -1,6 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from os import path as osp | |||
from typing import Any, Dict | |||
import json | |||
@@ -23,7 +24,8 @@ from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer | |||
from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg | |||
from modelscope.models.multi_modal.ofa.generate.search import Sampling | |||
from modelscope.models.multi_modal.ofa.generate.utils import move_to_device | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
try: | |||
from torchvision.transforms import InterpolationMode | |||
@@ -133,6 +135,8 @@ class OfaForTextToImageSynthesis(Model): | |||
super().__init__(model_dir=model_dir, *args, **kwargs) | |||
# Initialize ofa | |||
model = OFAModel.from_pretrained(model_dir) | |||
self.cfg = Config.from_file( | |||
osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
self.model = model.module if hasattr(model, 'module') else model | |||
self.tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) | |||
@@ -171,6 +175,8 @@ class OfaForTextToImageSynthesis(Model): | |||
'gen_code': True, | |||
'constraint_range': '50265,58457' | |||
} | |||
if hasattr(self.cfg.model, 'beam_search'): | |||
sg_args.update(self.cfg.model.beam_search) | |||
self.generator = sg.SequenceGenerator(**sg_args) | |||
def clip_tokenize(self, texts, context_length=77, truncate=False): | |||
@@ -8,7 +8,6 @@ from torch import nn | |||
from modelscope.metainfo import Heads | |||
from modelscope.models.base import TorchHead | |||
from modelscope.models.builder import HEADS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
@@ -27,9 +26,8 @@ class TextGenerationHead(TorchHead): | |||
def forward(self, inputs=None): | |||
logits = self.linear(inputs) | |||
return {OutputKeys.LOGITS: logits} | |||
return logits | |||
def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||
def compute_loss(self, logits: torch.Tensor, | |||
labels) -> Dict[str, torch.Tensor]: | |||
logits = outputs[OutputKeys.LOGITS] | |||
return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} | |||
return F.cross_entropy(logits, labels) |
@@ -1,7 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import addict | |||
import numpy as np | |||
from transformers.modeling_utils import PreTrainedModel | |||
@@ -9,7 +8,8 @@ from modelscope.metainfo import TaskModels | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.task_models.task_model import \ | |||
SingleBackboneTaskModelBase | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.outputs import (OutputKeys, TextGenerationModelOutput, | |||
TokenGeneratorOutput) | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['TaskModelForTextGeneration'] | |||
@@ -43,12 +43,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||
backbone_outputs = super().forward(input) | |||
hidden_states = backbone_outputs[0] | |||
outputs = self.head.forward(hidden_states) | |||
logits = self.head.forward(hidden_states) | |||
loss = None | |||
if labels is not None: | |||
input[OutputKeys.LABELS] = labels | |||
loss = self.compute_loss(outputs, labels) | |||
outputs.update(loss) | |||
return addict.Dict(outputs) | |||
loss = self.compute_loss(logits, labels) | |||
return TextGenerationModelOutput(logits=logits, loss=loss) | |||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): | |||
# only last token for inputs_ids if past is defined in kwargs | |||
@@ -76,4 +76,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||
def generate(self, inputs, *args, **kwargs): | |||
input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs | |||
return super().generate(input_ids, *args, **kwargs) | |||
generate_output = super().generate(input_ids, *args, **kwargs) | |||
if isinstance(generate_output, Dict): | |||
return TokenGeneratorOutput( | |||
sequences=generate_output.sequences, | |||
scores=generate_output.scores, | |||
attentions=generate_output.attentions, | |||
hidden_states=generate_output.hidden_states) | |||
else: | |||
return TokenGeneratorOutput(sequences=generate_output) |
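A short, hypothetical usage sketch of the uniform return type introduced above; the variable names are placeholders:
# With the wrapper, callers read .sequences regardless of whether the underlying
# HuggingFace generate() call returned a plain tensor or a dict-like output.
output = model.generate({'input_ids': input_ids})  # model: a TaskModelForTextGeneration instance
token_ids = output.sequences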
@@ -66,7 +66,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase): | |||
attentions=outputs.attentions, | |||
offset_mapping=input['offset_mapping'], | |||
) | |||
return outputs | |||
def extract_logits(self, outputs): | |||
return outputs[OutputKeys.LOGITS].cpu().detach() |
@@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module): | |||
pt_att *= pt_att | |||
pt_att = pt_att.sum(dim=-1) | |||
head_weights = self.softplus(self.head_weights).view( | |||
*((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) | |||
head_weights = self.softplus(self.head_weights).view( # noqa | |||
*((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) # noqa | |||
head_weights = head_weights * math.sqrt( | |||
1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | |||
pt_att *= head_weights * (-0.5) | |||
@@ -20,13 +20,15 @@ from modelscope.msdatasets.task_datasets.builder import build_task_dataset | |||
from modelscope.msdatasets.utils.dataset_builder import ExternalDataset | |||
from modelscope.msdatasets.utils.dataset_utils import ( | |||
get_dataset_files, get_target_dataset_structure, load_dataset_builder) | |||
from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager | |||
from modelscope.msdatasets.utils.download_utils import DatasetDownloadManager | |||
from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager | |||
from modelscope.utils.config import ConfigDict | |||
from modelscope.utils.config_ds import MS_DATASETS_CACHE | |||
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, | |||
DEFAULT_DATASET_REVISION, | |||
DatasetFormations, DownloadMode, Hubs) | |||
DatasetFormations, DownloadMode, Hubs, | |||
UploadMode) | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -576,15 +578,17 @@ class MsDataset: | |||
return self._hf_ds.rename_columns(column_mapping) | |||
@staticmethod | |||
def upload(object_name: str, | |||
local_file_path: str, | |||
dataset_name: str, | |||
namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, | |||
version: Optional[str] = DEFAULT_DATASET_REVISION, | |||
num_processes: Optional[int] = None, | |||
chunksize: Optional[int] = 1, | |||
filter_hidden_files: Optional[bool] = True) -> None: | |||
"""Upload dataset file or directory to the ModelScope Hub. Please login to the ModelScope Hub first. | |||
def upload( | |||
object_name: str, | |||
local_file_path: str, | |||
dataset_name: str, | |||
namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, | |||
version: Optional[str] = DEFAULT_DATASET_REVISION, | |||
num_processes: Optional[int] = None, | |||
chunksize: Optional[int] = 1, | |||
filter_hidden_files: Optional[bool] = True, | |||
upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None: | |||
"""Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first. | |||
Args: | |||
object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name | |||
@@ -592,7 +596,7 @@ class MsDataset: | |||
dataset_name (str): Name of the dataset | |||
namespace(str, optional): Namespace of the dataset | |||
version: Optional[str]: Version of the dataset | |||
num_processes: Optional[int]: The number of processes used for multi-process uploading. | |||
num_processes: Optional[int]: The number of processes used for multiprocess uploading. | |||
This is only applicable when local_file_path is a directory, and we are uploading multiple files | |||
inside the directory. When None is provided, the number returned by os.cpu_count() is used as the default. | |||
chunksize: Optional[int]: The chunksize of objects to upload. | |||
@@ -600,24 +604,34 @@ class MsDataset: | |||
using the default value of 1. Available if local_file_path is a directory. | |||
filter_hidden_files: Optional[bool]: Whether to filter hidden files. | |||
Available if local_file_path is a directory. | |||
upload_mode: Optional[UploadMode]: How to upload local objects. Defaults to UploadMode.OVERWRITE: all local | |||
objects are uploaded, and existing remote objects may be overwritten. | |||
Returns: | |||
None | |||
""" | |||
if not object_name: | |||
raise ValueError('object_name cannot be empty!') | |||
_upload_manager = DatasetUploadManager( | |||
dataset_name=dataset_name, namespace=namespace, version=version) | |||
upload_mode = UploadMode(upload_mode or UploadMode.OVERWRITE) | |||
if os.path.isfile(local_file_path): | |||
_upload_manager.upload( | |||
object_name=object_name, local_file_path=local_file_path) | |||
object_name=object_name, | |||
local_file_path=local_file_path, | |||
upload_mode=upload_mode) | |||
elif os.path.isdir(local_file_path): | |||
_upload_manager.upload_dir( | |||
object_dir_name=object_name, | |||
local_dir_path=local_file_path, | |||
num_processes=num_processes, | |||
chunksize=chunksize, | |||
filter_hidden_files=filter_hidden_files) | |||
filter_hidden_files=filter_hidden_files, | |||
upload_mode=upload_mode) | |||
else: | |||
raise ValueError( | |||
f'{local_file_path} is not a valid file path or directory') | |||
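A hedged usage sketch of the extended upload API; dataset names and paths are placeholders, and a prior HubApi login is assumed:
from modelscope.msdatasets import MsDataset
from modelscope.utils.constant import UploadMode

MsDataset.upload(
    object_name='train_images',                 # remote object name (placeholder)
    local_file_path='/path/to/local/images',    # local file or directory (placeholder)
    dataset_name='my-dataset',
    namespace='my-namespace',
    upload_mode=UploadMode.OVERWRITE)           # default mode, shown explicitly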
@@ -672,7 +686,7 @@ class MsDataset: | |||
revision of the model you want to clone from. Can be any of a branch, tag or commit hash | |||
auth_token(`Optional[str]`): | |||
token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter | |||
as the token is already saved when you login the first time, if None, we will use saved token. | |||
as the token is already saved when you log in the first time, if None, we will use saved token. | |||
git_path:(`Optional[str]`): | |||
The git command line path, if None, we use 'git' | |||
force (Optional[bool]): whether to use forced-push. | |||
@@ -687,8 +701,29 @@ class MsDataset: | |||
revision=revision, | |||
auth_token=auth_token, | |||
git_path=git_path) | |||
_repo.push( | |||
commit_message=commit_message, | |||
local_branch=revision, | |||
remote_branch=revision, | |||
force=force) | |||
_repo.push(commit_message=commit_message, branch=revision, force=force) | |||
@staticmethod | |||
def delete(object_name: str, | |||
dataset_name: str, | |||
namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, | |||
version: Optional[str] = DEFAULT_DATASET_REVISION) -> str: | |||
""" Delete object of dataset. Please log in first and make sure you have permission to manage the dataset. | |||
Args: | |||
object_name (str): The object name of the dataset to be deleted. Can be the name of a file or a directory; a | |||
directory name must end with `/`. | |||
For example: your-data-name.zip, train/001/img_001.png, train/, ... | |||
dataset_name (str): Path or name of the dataset. | |||
namespace(str, optional): Namespace of the dataset. | |||
version (str, optional): Version of the dataset. | |||
Returns: | |||
res_msg (str): Response message. | |||
""" | |||
_delete_manager = DatasetDeleteManager( | |||
dataset_name=dataset_name, namespace=namespace, version=version) | |||
resp_msg = _delete_manager.delete(object_name=object_name) | |||
logger.info(f'Object {object_name} successfully removed!') | |||
return resp_msg |
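A matching usage sketch for the new delete API; names are placeholders, and the object-name forms follow the docstring above:
from modelscope.msdatasets import MsDataset

msg = MsDataset.delete(object_name='train/001/img_001.png',  # a single file
                       dataset_name='my-dataset',
                       namespace='my-namespace')
# to delete a whole directory, pass a trailing slash, e.g. object_name='train/'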
@@ -13,6 +13,7 @@ if TYPE_CHECKING: | |||
from .video_summarization_dataset import VideoSummarizationDataset | |||
from .image_inpainting import ImageInpaintingDataset | |||
from .text_ranking_dataset import TextRankingDataset | |||
from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset | |||
else: | |||
_import_structure = { | |||
@@ -29,6 +30,8 @@ else: | |||
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | |||
'image_portrait_enhancement_dataset': | |||
['ImagePortraitEnhancementDataset'], | |||
'referring_video_object_segmentation': | |||
['ReferringVideoObjectSegmentationDataset'], | |||
} | |||
import sys | |||
@@ -0,0 +1,3 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .referring_video_object_segmentation_dataset import \ | |||
ReferringVideoObjectSegmentationDataset |
@@ -0,0 +1,361 @@ | |||
# Part of the implementation is borrowed and modified from MTTR, | |||
# publicly available at https://github.com/mttr2021/MTTR | |||
from glob import glob | |||
from os import path as osp | |||
import h5py | |||
import json | |||
import numpy as np | |||
import pandas | |||
import torch | |||
import torch.distributed as dist | |||
import torchvision.transforms.functional as F | |||
from pycocotools.mask import area, encode | |||
from torchvision.io import read_video | |||
from tqdm import tqdm | |||
from modelscope.metainfo import Models | |||
from modelscope.models.cv.referring_video_object_segmentation.utils import \ | |||
nested_tensor_from_videos_list | |||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
TorchTaskDataset | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from . import transformers as T | |||
LOGGER = get_logger() | |||
def get_image_id(video_id, frame_idx, ref_instance_a2d_id): | |||
image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}' | |||
return image_id | |||
@TASK_DATASETS.register_module( | |||
Tasks.referring_video_object_segmentation, | |||
module_name=Models.referring_video_object_segmentation) | |||
class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): | |||
def __init__(self, **kwargs): | |||
split_config = kwargs['split_config'] | |||
LOGGER.info(kwargs) | |||
data_cfg = kwargs.get('cfg').data_kwargs | |||
trans_cfg = kwargs.get('cfg').transformers_kwargs | |||
distributed = data_cfg.get('distributed', False) | |||
self.data_root = next(iter(split_config.values())) | |||
if not osp.exists(self.data_root): | |||
self.data_root = osp.dirname(self.data_root) | |||
assert osp.exists(self.data_root) | |||
self.window_size = data_cfg.get('window_size', 8) | |||
self.mask_annotations_dir = osp.join( | |||
self.data_root, 'text_annotations/annotation_with_instances') | |||
self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320') | |||
self.subset_type = next(iter(split_config.keys())) | |||
self.text_annotations = self.get_text_annotations( | |||
self.data_root, self.subset_type, distributed) | |||
self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg) | |||
self.collator = Collator() | |||
self.ann_file = osp.join( | |||
self.data_root, | |||
data_cfg.get('ann_file', | |||
'a2d_sentences_test_annotations_in_coco_format.json')) | |||
# create ground-truth test annotations for the evaluation process if necessary: | |||
if self.subset_type == 'test' and not osp.exists(self.ann_file): | |||
if (distributed and dist.get_rank() == 0) or not distributed: | |||
create_a2d_sentences_ground_truth_test_annotations( | |||
self.data_root, self.subset_type, | |||
self.mask_annotations_dir, self.ann_file) | |||
if distributed: | |||
dist.barrier() | |||
def __len__(self): | |||
return len(self.text_annotations) | |||
def __getitem__(self, idx): | |||
text_query, video_id, frame_idx, instance_id = self.text_annotations[ | |||
idx] | |||
text_query = ' '.join( | |||
text_query.lower().split()) # clean up the text query | |||
# read the source window frames: | |||
video_frames, _, _ = read_video( | |||
osp.join(self.videos_dir, f'{video_id}.mp4'), | |||
pts_unit='sec') # (T, H, W, C) | |||
# get a window of window_size frames with frame frame_idx in the middle. | |||
# note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx | |||
start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + ( | |||
self.window_size + 1) // 2 | |||
# extract the window source frames: | |||
source_frames = [] | |||
for i in range(start_idx, end_idx): | |||
i = min(max(i, 0), | |||
len(video_frames) | |||
- 1) # pad out of range indices with edge frames | |||
source_frames.append( | |||
F.to_pil_image(video_frames[i].permute(2, 0, 1))) | |||
# read the instance mask: | |||
frame_annot_path = osp.join(self.mask_annotations_dir, video_id, | |||
f'{frame_idx:05d}.h5') | |||
f = h5py.File(frame_annot_path, 'r') | |||
instances = list(f['instance']) | |||
instance_idx = instances.index( | |||
instance_id) # existence was already validated during init | |||
instance_masks = np.array(f['reMask']) | |||
if len(instances) == 1: | |||
instance_masks = instance_masks[np.newaxis, ...] | |||
instance_masks = torch.tensor(instance_masks).transpose(1, 2) | |||
mask_rles = [encode(mask) for mask in instance_masks.numpy()] | |||
mask_areas = area(mask_rles).astype(float) | |||
f.close() | |||
# create the target dict for the center frame: | |||
target = { | |||
'masks': instance_masks, | |||
'orig_size': instance_masks. | |||
shape[-2:], # original frame shape without any augmentations | |||
# size with augmentations, will be changed inside transforms if necessary | |||
'size': instance_masks.shape[-2:], | |||
'referred_instance_idx': torch.tensor( | |||
instance_idx), # idx in 'masks' of the text referred instance | |||
'area': torch.tensor(mask_areas), | |||
'iscrowd': | |||
torch.zeros(len(instance_masks) | |||
), # for compatibility with DETR COCO transforms | |||
'image_id': get_image_id(video_id, frame_idx, instance_id) | |||
} | |||
# create dummy targets for adjacent frames: | |||
targets = self.window_size * [None] | |||
center_frame_idx = self.window_size // 2 | |||
targets[center_frame_idx] = target | |||
source_frames, targets, text_query = self.transforms( | |||
source_frames, targets, text_query) | |||
return source_frames, targets, text_query | |||
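# Worked example of the window arithmetic above (illustrative numbers): with window_size=8 and a
# 1-indexed frame_idx=10, start_idx = 10 - 1 - 4 = 5 and end_idx = 10 - 1 + 4 = 13, so 0-indexed
# frames 5..12 are read, the annotated frame (index 9) sits at center position window_size // 2 = 4,
# and out-of-range indices are clamped to the first/last frame of the clip.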
@staticmethod | |||
def get_text_annotations(root_path, subset, distributed): | |||
saved_annotations_file_path = osp.join( | |||
root_path, f'sentences_single_frame_{subset}_annotations.json') | |||
if osp.exists(saved_annotations_file_path): | |||
with open(saved_annotations_file_path, 'r') as f: | |||
text_annotations_by_frame = [tuple(a) for a in json.load(f)] | |||
return text_annotations_by_frame | |||
elif (distributed and dist.get_rank() == 0) or not distributed: | |||
print(f'building a2d sentences {subset} text annotations...') | |||
# without header=None pandas treats the first row as a header and drops the first sample | |||
a2d_data_info = pandas.read_csv( | |||
osp.join(root_path, 'Release/videoset.csv'), header=None) | |||
# 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' | |||
a2d_data_info.columns = [ | |||
'vid', '', '', '', '', '', '', '', 'subset' | |||
] | |||
with open( | |||
osp.join(root_path, 'text_annotations/missed_videos.txt'), | |||
'r') as f: | |||
unused_videos = f.read().splitlines() | |||
subsets = {'train': 0, 'test': 1} | |||
# filter unused videos and videos which do not belong to our train/test subset: | |||
used_videos = a2d_data_info[ | |||
~a2d_data_info.vid.isin(unused_videos) | |||
& (a2d_data_info.subset == subsets[subset])] | |||
used_videos_ids = list(used_videos['vid']) | |||
text_annotations = pandas.read_csv( | |||
osp.join(root_path, 'text_annotations/annotation.txt')) | |||
# filter the text annotations based on the used videos: | |||
used_text_annotations = text_annotations[ | |||
text_annotations.video_id.isin(used_videos_ids)] | |||
# remove a single dataset annotation mistake in video: T6bNPuKV-wY | |||
used_text_annotations = used_text_annotations[ | |||
used_text_annotations['instance_id'] != '1 (copy)'] | |||
# convert data-frame to list of tuples: | |||
used_text_annotations = list( | |||
used_text_annotations.to_records(index=False)) | |||
text_annotations_by_frame = [] | |||
mask_annotations_dir = osp.join( | |||
root_path, 'text_annotations/annotation_with_instances') | |||
for video_id, instance_id, text_query in tqdm( | |||
used_text_annotations): | |||
frame_annot_paths = sorted( | |||
glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) | |||
instance_id = int(instance_id) | |||
for p in frame_annot_paths: | |||
f = h5py.File(p, 'r') | |||
instances = list(f['instance']) | |||
if instance_id in instances: | |||
# if this instance does not appear in this frame it has no ground-truth mask, so this | |||
# frame-instance pair is ignored in evaluation, as done in the SOTA method CMPC-V. see: | |||
# https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 | |||
frame_idx = int(p.split('/')[-1].split('.')[0]) | |||
text_query = text_query.lower( | |||
) # lower the text query prior to augmentation & tokenization | |||
text_annotations_by_frame.append( | |||
(text_query, video_id, frame_idx, instance_id)) | |||
with open(saved_annotations_file_path, 'w') as f: | |||
json.dump(text_annotations_by_frame, f) | |||
if distributed: | |||
dist.barrier() | |||
with open(saved_annotations_file_path, 'r') as f: | |||
text_annotations_by_frame = [tuple(a) for a in json.load(f)] | |||
return text_annotations_by_frame | |||
class A2dSentencesTransforms: | |||
def __init__(self, subset_type, horizontal_flip_augmentations, | |||
resize_and_crop_augmentations, train_short_size, | |||
train_max_size, eval_short_size, eval_max_size, **kwargs): | |||
self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations | |||
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |||
scales = [ | |||
train_short_size | |||
] # no more scales for now due to GPU memory constraints. might be changed later | |||
transforms = [] | |||
if resize_and_crop_augmentations: | |||
if subset_type == 'train': | |||
transforms.append( | |||
T.RandomResize(scales, max_size=train_max_size)) | |||
elif subset_type == 'test': | |||
transforms.append( | |||
T.RandomResize([eval_short_size], max_size=eval_max_size)) | |||
transforms.extend([T.ToTensor(), normalize]) | |||
self.size_transforms = T.Compose(transforms) | |||
def __call__(self, source_frames, targets, text_query): | |||
if self.h_flip_augmentation and torch.rand(1) > 0.5: | |||
source_frames = [F.hflip(f) for f in source_frames] | |||
targets[len(targets) // 2]['masks'] = F.hflip( | |||
targets[len(targets) // 2]['masks']) | |||
# Note - it is possible for both 'right' and 'left' to appear together in the same query, hence this fix: | |||
text_query = text_query.replace('left', '@').replace( | |||
'right', 'left').replace('@', 'right') | |||
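# e.g. 'a man on the left moving to the right' becomes 'a man on the right moving to the left'; | |||
# the '@' placeholder prevents the first replacement from being undone by the second one | |||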
source_frames, targets = list( | |||
zip(*[ | |||
self.size_transforms(f, t) | |||
for f, t in zip(source_frames, targets) | |||
])) | |||
source_frames = torch.stack(source_frames) # [T, 3, H, W] | |||
return source_frames, targets, text_query | |||
class Collator: | |||
def __call__(self, batch): | |||
samples, targets, text_queries = list(zip(*batch)) | |||
samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W] | |||
# convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch | |||
targets = list(zip(*targets)) | |||
batch_dict = { | |||
'samples': samples, | |||
'targets': targets, | |||
'text_queries': text_queries | |||
} | |||
return batch_dict | |||
def get_text_annotations_gt(root_path, subset): | |||
# without header=None pandas treats the first row as a header and drops the first sample | |||
a2d_data_info = pandas.read_csv( | |||
osp.join(root_path, 'Release/videoset.csv'), header=None) | |||
# 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' | |||
a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset'] | |||
with open(osp.join(root_path, 'text_annotations/missed_videos.txt'), | |||
'r') as f: | |||
unused_videos = f.read().splitlines() | |||
subsets = {'train': 0, 'test': 1} | |||
# filter unused videos and videos which do not belong to our train/test subset: | |||
used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos) | |||
& (a2d_data_info.subset == subsets[subset])] | |||
used_videos_ids = list(used_videos['vid']) | |||
text_annotations = pandas.read_csv( | |||
osp.join(root_path, 'text_annotations/annotation.txt')) | |||
# filter the text annotations based on the used videos: | |||
used_text_annotations = text_annotations[text_annotations.video_id.isin( | |||
used_videos_ids)] | |||
# convert data-frame to list of tuples: | |||
used_text_annotations = list(used_text_annotations.to_records(index=False)) | |||
return used_text_annotations | |||
def create_a2d_sentences_ground_truth_test_annotations(dataset_path, | |||
subset_type, | |||
mask_annotations_dir, | |||
output_path): | |||
text_annotations = get_text_annotations_gt(dataset_path, subset_type) | |||
# Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly | |||
# expected by pycocotools as it is the convention of the original coco dataset annotations. | |||
categories_dict = [{ | |||
'id': 1, | |||
'name': 'dummy_class' | |||
}] # dummy class, as categories are not used/predicted in RVOS | |||
images_dict = [] | |||
annotations_dict = [] | |||
images_set = set() | |||
instance_id_counter = 1 | |||
for annot in tqdm(text_annotations): | |||
video_id, instance_id, text_query = annot | |||
annot_paths = sorted( | |||
glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) | |||
for p in annot_paths: | |||
f = h5py.File(p, 'r') | |||
instances = list(f['instance']) | |||
try: | |||
instance_idx = instances.index(int(instance_id)) | |||
# if this instance does not appear in this frame it has no ground-truth mask, so this | |||
# frame-instance pair is ignored in evaluation, as done in the SOTA method CMPC-V. see: | |||
# https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 | |||
except ValueError: | |||
continue # instance_id does not appear in current frame | |||
mask = f['reMask'][instance_idx] if len( | |||
instances) > 1 else np.array(f['reMask']) | |||
mask = mask.transpose() | |||
frame_idx = int(p.split('/')[-1].split('.')[0]) | |||
image_id = get_image_id(video_id, frame_idx, instance_id) | |||
assert image_id not in images_set, f'error: image id: {image_id} appeared twice' | |||
images_set.add(image_id) | |||
images_dict.append({ | |||
'id': image_id, | |||
'height': mask.shape[0], | |||
'width': mask.shape[1] | |||
}) | |||
mask_rle = encode(mask) | |||
mask_rle['counts'] = mask_rle['counts'].decode('ascii') | |||
mask_area = float(area(mask_rle)) | |||
bbox = f['reBBox'][:, instance_idx] if len( | |||
instances) > 1 else np.array( | |||
f['reBBox']).squeeze() # x1y1x2y2 form | |||
bbox_xywh = [ | |||
bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] | |||
] | |||
instance_annot = { | |||
'id': instance_id_counter, | |||
'image_id': image_id, | |||
'category_id': | |||
1, # dummy class, as categories are not used/predicted in ref-vos | |||
'segmentation': mask_rle, | |||
'area': mask_area, | |||
'bbox': bbox_xywh, | |||
'iscrowd': 0, | |||
} | |||
annotations_dict.append(instance_annot) | |||
instance_id_counter += 1 | |||
dataset_dict = { | |||
'categories': categories_dict, | |||
'images': images_dict, | |||
'annotations': annotations_dict | |||
} | |||
with open(output_path, 'w') as f: | |||
json.dump(dataset_dict, f) |
@@ -0,0 +1,294 @@ | |||
# The implementation is adopted from MTTR, | |||
# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR | |||
# Modified from DETR https://github.com/facebookresearch/detr | |||
import random | |||
import PIL | |||
import torch | |||
import torchvision.transforms as T | |||
import torchvision.transforms.functional as F | |||
from modelscope.models.cv.referring_video_object_segmentation.utils import \ | |||
interpolate | |||
def crop(image, target, region): | |||
cropped_image = F.crop(image, *region) | |||
target = target.copy() | |||
i, j, h, w = region | |||
# should we do something wrt the original size? | |||
target['size'] = torch.tensor([h, w]) | |||
fields = ['labels', 'area', 'iscrowd'] | |||
if 'boxes' in target: | |||
boxes = target['boxes'] | |||
max_size = torch.as_tensor([w, h], dtype=torch.float32) | |||
cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) | |||
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) | |||
cropped_boxes = cropped_boxes.clamp(min=0) | |||
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) | |||
target['boxes'] = cropped_boxes.reshape(-1, 4) | |||
target['area'] = area | |||
fields.append('boxes') | |||
if 'masks' in target: | |||
# FIXME should we update the area here if there are no boxes? | |||
target['masks'] = target['masks'][:, i:i + h, j:j + w] | |||
fields.append('masks') | |||
# remove elements for which the boxes or masks have zero area | |||
if 'boxes' in target or 'masks' in target: | |||
# favor boxes selection when defining which elements to keep | |||
# this is compatible with previous implementation | |||
if 'boxes' in target: | |||
cropped_boxes = target['boxes'].reshape(-1, 2, 2) | |||
keep = torch.all( | |||
cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) | |||
else: | |||
keep = target['masks'].flatten(1).any(1) | |||
for field in fields: | |||
target[field] = target[field][keep] | |||
return cropped_image, target | |||
def hflip(image, target): | |||
flipped_image = F.hflip(image) | |||
w, h = image.size | |||
target = target.copy() | |||
if 'boxes' in target: | |||
boxes = target['boxes'] | |||
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( | |||
[-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) | |||
target['boxes'] = boxes | |||
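# i.e. each box [x1, y1, x2, y2] becomes [w - x2, y1, w - x1, y2] after the horizontal flip | |||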
if 'masks' in target: | |||
target['masks'] = target['masks'].flip(-1) | |||
return flipped_image, target | |||
def resize(image, target, size, max_size=None): | |||
# size can be min_size (scalar) or (w, h) tuple | |||
def get_size_with_aspect_ratio(image_size, size, max_size=None): | |||
w, h = image_size | |||
if max_size is not None: | |||
min_original_size = float(min((w, h))) | |||
max_original_size = float(max((w, h))) | |||
if max_original_size / min_original_size * size > max_size: | |||
size = int( | |||
round(max_size * min_original_size / max_original_size)) | |||
if (w <= h and w == size) or (h <= w and h == size): | |||
return (h, w) | |||
if w < h: | |||
ow = size | |||
oh = int(size * h / w) | |||
else: | |||
oh = size | |||
ow = int(size * w / h) | |||
return (oh, ow) | |||
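# illustrative example (added note): a 1280x720 image with size=640 and max_size=960 first reduces | |||
# size to 540 so the longer side stays within max_size, yielding an output size of (540, 960) | |||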
def get_size(image_size, size, max_size=None): | |||
if isinstance(size, (list, tuple)): | |||
return size[::-1] | |||
else: | |||
return get_size_with_aspect_ratio(image_size, size, max_size) | |||
size = get_size(image.size, size, max_size) | |||
rescaled_image = F.resize(image, size) | |||
if target is None: | |||
return rescaled_image, None | |||
ratios = tuple( | |||
float(s) / float(s_orig) | |||
for s, s_orig in zip(rescaled_image.size, image.size)) | |||
ratio_width, ratio_height = ratios | |||
target = target.copy() | |||
if 'boxes' in target: | |||
boxes = target['boxes'] | |||
scaled_boxes = boxes * torch.as_tensor( | |||
[ratio_width, ratio_height, ratio_width, ratio_height]) | |||
target['boxes'] = scaled_boxes | |||
if 'area' in target: | |||
area = target['area'] | |||
scaled_area = area * (ratio_width * ratio_height) | |||
target['area'] = scaled_area | |||
h, w = size | |||
target['size'] = torch.tensor([h, w]) | |||
if 'masks' in target: | |||
target['masks'] = interpolate( | |||
target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5 | |||
return rescaled_image, target | |||
def pad(image, target, padding): | |||
# assumes that we only pad on the bottom and right sides | |||
padded_image = F.pad(image, (0, 0, padding[0], padding[1])) | |||
if target is None: | |||
return padded_image, None | |||
target = target.copy() | |||
# should we do something wrt the original size? | |||
target['size'] = torch.tensor(padded_image.size[::-1]) | |||
if 'masks' in target: | |||
target['masks'] = torch.nn.functional.pad( | |||
target['masks'], (0, padding[0], 0, padding[1])) | |||
return padded_image, target | |||
class RandomCrop(object): | |||
def __init__(self, size): | |||
self.size = size | |||
def __call__(self, img, target): | |||
region = T.RandomCrop.get_params(img, self.size) | |||
return crop(img, target, region) | |||
class RandomSizeCrop(object): | |||
def __init__(self, min_size: int, max_size: int): | |||
self.min_size = min_size | |||
self.max_size = max_size | |||
def __call__(self, img: PIL.Image.Image, target: dict): | |||
w = random.randint(self.min_size, min(img.width, self.max_size)) | |||
h = random.randint(self.min_size, min(img.height, self.max_size)) | |||
region = T.RandomCrop.get_params(img, [h, w]) | |||
return crop(img, target, region) | |||
class CenterCrop(object): | |||
def __init__(self, size): | |||
self.size = size | |||
def __call__(self, img, target): | |||
image_width, image_height = img.size | |||
crop_height, crop_width = self.size | |||
crop_top = int(round((image_height - crop_height) / 2.)) | |||
crop_left = int(round((image_width - crop_width) / 2.)) | |||
return crop(img, target, | |||
(crop_top, crop_left, crop_height, crop_width)) | |||
class RandomHorizontalFlip(object): | |||
def __init__(self, p=0.5): | |||
self.p = p | |||
def __call__(self, img, target): | |||
if random.random() < self.p: | |||
return hflip(img, target) | |||
return img, target | |||
class RandomResize(object): | |||
def __init__(self, sizes, max_size=None): | |||
assert isinstance(sizes, (list, tuple)) | |||
self.sizes = sizes | |||
self.max_size = max_size | |||
def __call__(self, img, target=None): | |||
size = random.choice(self.sizes) | |||
return resize(img, target, size, self.max_size) | |||
class RandomPad(object): | |||
def __init__(self, max_pad): | |||
self.max_pad = max_pad | |||
def __call__(self, img, target): | |||
pad_x = random.randint(0, self.max_pad) | |||
pad_y = random.randint(0, self.max_pad) | |||
return pad(img, target, (pad_x, pad_y)) | |||
class RandomSelect(object): | |||
""" | |||
Randomly selects between transforms1 and transforms2, | |||
with probability p for transforms1 and (1 - p) for transforms2 | |||
""" | |||
def __init__(self, transforms1, transforms2, p=0.5): | |||
self.transforms1 = transforms1 | |||
self.transforms2 = transforms2 | |||
self.p = p | |||
def __call__(self, img, target): | |||
if random.random() < self.p: | |||
return self.transforms1(img, target) | |||
return self.transforms2(img, target) | |||
class ToTensor(object): | |||
def __call__(self, img, target): | |||
return F.to_tensor(img), target | |||
class RandomErasing(object): | |||
def __init__(self, *args, **kwargs): | |||
self.eraser = T.RandomErasing(*args, **kwargs) | |||
def __call__(self, img, target): | |||
return self.eraser(img), target | |||
class Normalize(object): | |||
def __init__(self, mean, std): | |||
self.mean = mean | |||
self.std = std | |||
def __call__(self, image, target=None): | |||
image = F.normalize(image, mean=self.mean, std=self.std) | |||
if target is None: | |||
return image, None | |||
target = target.copy() | |||
h, w = image.shape[-2:] | |||
if 'boxes' in target: | |||
boxes = target['boxes'] | |||
boxes = box_xyxy_to_cxcywh(boxes) | |||
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) | |||
target['boxes'] = boxes | |||
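# boxes are converted from absolute xyxy coordinates to cxcywh normalized by the image size, | |||
# which is the box format expected by DETR-style criteria | |||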
return image, target | |||
class Compose(object): | |||
def __init__(self, transforms): | |||
self.transforms = transforms | |||
def __call__(self, image, target): | |||
for t in self.transforms: | |||
image, target = t(image, target) | |||
return image, target | |||
def __repr__(self): | |||
format_string = self.__class__.__name__ + '(' | |||
for t in self.transforms: | |||
format_string += '\n' | |||
format_string += ' {0}'.format(t) | |||
format_string += '\n)' | |||
return format_string |
@@ -82,7 +82,7 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, | |||
dataset_name: str, namespace: str, | |||
version: str) -> list: | |||
""" | |||
List all of objects for specific dataset. | |||
List all objects of a specific dataset. | |||
Args: | |||
hub_api (class HubApi): HubApi instance. | |||
@@ -0,0 +1,32 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.hub.api import HubApi | |||
class DatasetDeleteManager(object): | |||
def __init__(self, dataset_name: str, namespace: str, version: str): | |||
self.api = HubApi() | |||
self.dataset_name = dataset_name | |||
self.namespace = namespace | |||
self.version = version | |||
def delete(self, object_name: str) -> str: | |||
# single object | |||
if not object_name.endswith('/'): | |||
resp_msg = self.api.delete_oss_dataset_object( | |||
object_name=object_name, | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace, | |||
revision=self.version) | |||
else: | |||
# multiple objects | |||
object_name = object_name.strip('/') | |||
resp_msg = self.api.delete_oss_dataset_dir( | |||
object_name=object_name, | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace, | |||
revision=self.version) | |||
return resp_msg |
@@ -27,7 +27,11 @@ class DatasetDownloadManager(DownloadManager): | |||
oss_config = api.get_dataset_access_config(self._dataset_name, | |||
self._namespace, | |||
self._version) | |||
self.oss_utilities = OssUtilities(oss_config) | |||
self.oss_utilities = OssUtilities( | |||
oss_config=oss_config, | |||
dataset_name=self._dataset_name, | |||
namespace=self._namespace, | |||
revision=self._version) | |||
def _download(self, url_or_filename: str, | |||
download_config: DownloadConfig) -> str: | |||
@@ -6,19 +6,28 @@ import os | |||
import oss2 | |||
from datasets.utils.file_utils import hash_url_to_filename | |||
from modelscope.hub.api import HubApi | |||
from modelscope.utils.constant import UploadMode | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
ACCESS_ID = 'AccessId' | |||
ACCESS_SECRET = 'AccessSecret' | |||
SECURITY_TOKEN = 'SecurityToken' | |||
BUCKET = 'Bucket' | |||
BACK_DIR = 'BackupDir' | |||
DIR = 'Dir' | |||
class OssUtilities: | |||
def __init__(self, oss_config): | |||
self.key = oss_config['AccessId'] | |||
self.secret = oss_config['AccessSecret'] | |||
self.token = oss_config['SecurityToken'] | |||
self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" | |||
self.bucket_name = oss_config['Bucket'] | |||
auth = oss2.StsAuth(self.key, self.secret, self.token) | |||
self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) | |||
self.oss_dir = oss_config['Dir'] | |||
self.oss_backup_dir = oss_config['BackupDir'] | |||
def __init__(self, oss_config, dataset_name, namespace, revision): | |||
self._do_init(oss_config=oss_config) | |||
self.dataset_name = dataset_name | |||
self.namespace = namespace | |||
self.revision = revision | |||
self.upload_resumable_tmp_store = '/tmp/modelscope/tmp_dataset' | |||
self.upload_multipart_threshold = 50 * 1024 * 1024 | |||
@@ -26,6 +35,28 @@ class OssUtilities: | |||
self.upload_num_threads = 4 | |||
self.upload_max_retries = 3 | |||
self.api = HubApi() | |||
def _do_init(self, oss_config): | |||
self.key = oss_config[ACCESS_ID] | |||
self.secret = oss_config[ACCESS_SECRET] | |||
self.token = oss_config[SECURITY_TOKEN] | |||
self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" | |||
self.bucket_name = oss_config[BUCKET] | |||
auth = oss2.StsAuth(self.key, self.secret, self.token) | |||
self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) | |||
self.oss_dir = oss_config[DIR] | |||
self.oss_backup_dir = oss_config[BACK_DIR] | |||
def _reload_sts(self): | |||
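# Refresh the STS credentials from the hub when the current token has expired | |||
# (the upload retry loop below calls this after catching an OSS 403 response). | |||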
cookies = self.api.check_local_cookies(use_cookies=True) | |||
oss_config_refresh = self.api.get_dataset_access_config_session( | |||
cookies=cookies, | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace, | |||
revision=self.revision) | |||
self._do_init(oss_config_refresh) | |||
@staticmethod | |||
def _percentage(consumed_bytes, total_bytes): | |||
if total_bytes: | |||
@@ -51,7 +82,8 @@ class OssUtilities: | |||
return local_path | |||
def upload(self, oss_object_name: str, local_file_path: str, | |||
indicate_individual_progress: bool) -> str: | |||
indicate_individual_progress: bool, | |||
upload_mode: UploadMode) -> str: | |||
retry_count = 0 | |||
object_key = os.path.join(self.oss_dir, oss_object_name) | |||
resumable_store = oss2.ResumableStore( | |||
@@ -64,6 +96,13 @@ class OssUtilities: | |||
while True: | |||
try: | |||
retry_count += 1 | |||
exist = self.bucket.object_exists(object_key) | |||
if upload_mode == UploadMode.APPEND and exist: | |||
logger.info( | |||
f'Skip {oss_object_name}: the object already exists and the upload mode is {upload_mode.value}.' | |||
) | |||
break | |||
oss2.resumable_upload( | |||
self.bucket, | |||
object_key, | |||
@@ -74,7 +113,9 @@ class OssUtilities: | |||
progress_callback=progress_callback, | |||
num_threads=self.upload_num_threads) | |||
break | |||
except Exception: | |||
except Exception as e: | |||
if getattr(e, 'status', None) == 403: | |||
self._reload_sts() | |||
if retry_count >= self.upload_max_retries: | |||
raise | |||
@@ -5,6 +5,7 @@ from multiprocessing.dummy import Pool as ThreadPool | |||
from tqdm import tqdm | |||
from modelscope.utils.constant import UploadMode | |||
from .oss_utils import OssUtilities | |||
@@ -13,38 +14,45 @@ class DatasetUploadManager(object): | |||
def __init__(self, dataset_name: str, namespace: str, version: str): | |||
from modelscope.hub.api import HubApi | |||
_hub_api = HubApi() | |||
_cookies = _hub_api.check_cookies_upload_data(use_cookies=True) | |||
_cookies = _hub_api.check_local_cookies(use_cookies=True) | |||
_oss_config = _hub_api.get_dataset_access_config_session( | |||
cookies=_cookies, | |||
dataset_name=dataset_name, | |||
namespace=namespace, | |||
revision=version) | |||
self.oss_utilities = OssUtilities(_oss_config) | |||
self.oss_utilities = OssUtilities( | |||
oss_config=_oss_config, | |||
dataset_name=dataset_name, | |||
namespace=namespace, | |||
revision=version) | |||
def upload(self, object_name: str, local_file_path: str) -> str: | |||
def upload(self, object_name: str, local_file_path: str, | |||
upload_mode: UploadMode) -> str: | |||
object_key = self.oss_utilities.upload( | |||
oss_object_name=object_name, | |||
local_file_path=local_file_path, | |||
indicate_individual_progress=True) | |||
indicate_individual_progress=True, | |||
upload_mode=upload_mode) | |||
return object_key | |||
def upload_dir(self, object_dir_name: str, local_dir_path: str, | |||
num_processes: int, chunksize: int, | |||
filter_hidden_files: bool) -> int: | |||
filter_hidden_files: bool, upload_mode: UploadMode) -> int: | |||
def run_upload(args): | |||
self.oss_utilities.upload( | |||
oss_object_name=args[0], | |||
local_file_path=args[1], | |||
indicate_individual_progress=False) | |||
indicate_individual_progress=False, | |||
upload_mode=upload_mode) | |||
files_list = [] | |||
for root, dirs, files in os.walk(local_dir_path): | |||
for file_name in files: | |||
if filter_hidden_files and file_name.startswith('.'): | |||
continue | |||
# Concatenate directory name and relative path into a oss object key. e.g., train/001/1_1230.png | |||
# Concatenate the directory name and relative path into an oss object key, e.g., train/001/1_1230.png | |||
object_name = os.path.join( | |||
object_dir_name, | |||
root.replace(local_dir_path, '', 1).strip('/'), file_name) | |||
@@ -541,3 +541,50 @@ class Seq2SeqLMOutput(ModelOutputBase): | |||
encoder_last_hidden_state: Optional[Tensor] = None | |||
encoder_hidden_states: Optional[Tuple[Tensor]] = None | |||
encoder_attentions: Optional[Tuple[Tensor]] = None | |||
@dataclass | |||
class TextGenerationModelOutput(ModelOutputBase): | |||
"""The output class for text generation models. | |||
Args: | |||
logits (`Tensor`): The logits output of the model. | |||
loss (`Tensor`, *optional*): The loss of the model, available when training. | |||
hidden_states (`Tensor`, *optional*): Hidden-states of the model at the | |||
output of each layer plus the optional initial embedding outputs. | |||
""" | |||
logits: Tensor = None | |||
loss: Tensor = None | |||
@dataclass | |||
class TokenGeneratorOutput(ModelOutputBase): | |||
""" | |||
The output class for generate method of text generation models. | |||
Args: | |||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): | |||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter | |||
if all batches finished early due to the `eos_token_id`. | |||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` | |||
is passed or when `config.output_scores=True`): | |||
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for | |||
each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. | |||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` | |||
is passed or `config.output_attentions=True`): | |||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |||
`torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, | |||
sequence_length)`. | |||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` | |||
is passed or when `config.output_hidden_states=True`): | |||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |||
`torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. | |||
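Examples (illustrative only; the exact `generate` call depends on the concrete model): | |||
>>> output = model.generate(**inputs)  # assumed here to return a TokenGeneratorOutput | |||
>>> output.sequences.shape  # (batch_size * num_return_sequences, sequence_length) | |||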
""" | |||
sequences: Tensor = None | |||
scores: Optional[Tuple[Tensor]] = None | |||
attentions: Optional[Tuple[Tuple[Tensor]]] = None | |||
hidden_states: Optional[Tuple[Tuple[Tensor]]] = None |
@@ -157,7 +157,13 @@ class ReferringVideoObjectSegmentationPipeline(Pipeline): | |||
* text_border_height_per_query, 0, 0)) | |||
W, H = vid_frame.size | |||
draw = ImageDraw.Draw(vid_frame) | |||
font = ImageFont.truetype(font='DejaVuSansMono.ttf', size=30) | |||
if self.model.cfg.pipeline.output_font: | |||
font = ImageFont.truetype( | |||
font=self.model.cfg.pipeline.output_font, | |||
size=self.model.cfg.pipeline.output_font_size) | |||
else: | |||
font = ImageFont.load_default() | |||
for i, (text_query, color) in enumerate( | |||
zip(self.text_queries, colors), start=1): | |||
w, h = draw.textsize(text_query, font=font) | |||
@@ -104,6 +104,10 @@ class TextGenerationPipeline(Pipeline): | |||
tokenizer = self.preprocessor.tokenizer | |||
return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) | |||
def sentence_piece(self, inputs) -> str: | |||
tokenizer = self.preprocessor.tokenizer | |||
return tokenizer.decode(inputs.tolist()) | |||
def roberta(self, inputs) -> str: | |||
tokenizer = self.preprocessor.tokenizer | |||
decoded = tokenizer.decode(inputs.tolist()) | |||
@@ -121,7 +125,7 @@ class TextGenerationPipeline(Pipeline): | |||
Dict[str, str]: the prediction results | |||
""" | |||
inputs = inputs['sequences'] | |||
if isinstance(inputs, list): | |||
if isinstance(inputs, list) or len(inputs.shape) > 1: | |||
inputs = inputs[0] | |||
decoded = getattr(self, self.postprocessor)(inputs) | |||
text = self._remove_space_between_chinese_chars(decoded) | |||
@@ -17,6 +17,8 @@ from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
__all__ = ['TokenClassificationPipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.token_classification, module_name=Pipelines.token_classification) | |||
@PIPELINES.register_module( | |||
Tasks.token_classification, module_name=Pipelines.part_of_speech) | |||
@PIPELINES.register_module( | |||
@@ -41,7 +43,7 @@ class TokenClassificationPipeline(Pipeline): | |||
str) else model | |||
if preprocessor is None: | |||
preprocessor = Model.from_pretrained( | |||
preprocessor = Preprocessor.from_pretrained( | |||
model.model_dir, | |||
sequence_length=kwargs.pop('sequence_length', 128)) | |||
model.eval() | |||
@@ -147,8 +147,50 @@ class Preprocessor(ABC): | |||
cfg_dict: Config = None, | |||
preprocessor_mode=ModeKeys.INFERENCE, | |||
**kwargs): | |||
""" Instantiate a model from local directory or remote model repo. Note | |||
"""Instantiate a preprocessor from local directory or remote model repo. Note | |||
that when loading from remote, the model revision can be specified. | |||
Args: | |||
model_name_or_path(str): A model dir or a model id used to load the preprocessor. | |||
revision(str, `optional`): The revision used when the model_name_or_path is | |||
a model id of the remote hub. default `master`. | |||
cfg_dict(Config, `optional`): An optional config. If provided, it will replace | |||
the config read out of the `model_name_or_path` | |||
preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor, can be `train`, `eval`, | |||
or `inference`. Default value `inference`. | |||
The preprocessor field in the config may contain two sub preprocessors: | |||
>>> { | |||
>>> "train": { | |||
>>> "type": "some-train-preprocessor" | |||
>>> }, | |||
>>> "val": { | |||
>>> "type": "some-eval-preprocessor" | |||
>>> } | |||
>>> } | |||
In this scenario, the `train` preprocessor will be loaded in the `train` mode, the `val` preprocessor | |||
will be loaded in the `eval` or `inference` mode. The `mode` field in the preprocessor class | |||
will be assigned in all the modes. | |||
Or just one: | |||
>>> { | |||
>>> "type": "some-train-preprocessor" | |||
>>> } | |||
In this scenario, the sole preprocessor will be loaded in all the modes, | |||
and the `mode` field in the preprocessor class will be assigned. | |||
**kwargs: | |||
task(str, `optional`): The `Tasks` enumeration value to replace the task value | |||
read out of config in the `model_name_or_path`. | |||
This is useful when the preprocessor does not have a `type` field and the task to be used is not | |||
equal to the task for which the model was saved. | |||
Other kwargs will be directly fed into the preprocessor, to replace the default configs. | |||
Returns: | |||
The preprocessor instance. | |||
Examples: | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base') | |||
""" | |||
if not os.path.exists(model_name_or_path): | |||
model_dir = snapshot_download( | |||
@@ -157,7 +157,7 @@ class MPlugPreprocessor(Preprocessor): | |||
def image_open(self, path: str) -> Tuple[Image.Image, int]: | |||
if path not in self._image_map: | |||
index = len(self._image_map) | |||
self._image_map[path] = (Image.open(path), index) | |||
self._image_map[path] = (load_image(path), index) | |||
return self._image_map[path] | |||
def __call__( | |||
@@ -9,7 +9,8 @@ if TYPE_CHECKING: | |||
from .builder import build_trainer | |||
from .cv import (ImageInstanceSegmentationTrainer, | |||
ImagePortraitEnhancementTrainer, | |||
MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||
MovieSceneSegmentationTrainer, ImageInpaintingTrainer, | |||
ReferringVideoObjectSegmentationTrainer) | |||
from .multi_modal import CLIPTrainer | |||
from .nlp import SequenceClassificationTrainer, TextRankingTrainer | |||
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments | |||
@@ -9,6 +9,7 @@ if TYPE_CHECKING: | |||
from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | |||
from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | |||
from .image_inpainting_trainer import ImageInpaintingTrainer | |||
from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer | |||
else: | |||
_import_structure = { | |||
@@ -17,7 +18,9 @@ else: | |||
'image_portrait_enhancement_trainer': | |||
['ImagePortraitEnhancementTrainer'], | |||
'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], | |||
'image_inpainting_trainer': ['ImageInpaintingTrainer'] | |||
'image_inpainting_trainer': ['ImageInpaintingTrainer'], | |||
'referring_video_object_segmentation_trainer': | |||
['ReferringVideoObjectSegmentationTrainer'] | |||
} | |||
import sys | |||
@@ -0,0 +1,63 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import torch | |||
from modelscope.metainfo import Trainers | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.trainer import EpochBasedTrainer | |||
from modelscope.utils.constant import ModeKeys | |||
@TRAINERS.register_module( | |||
module_name=Trainers.referring_video_object_segmentation) | |||
class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer): | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.model.set_postprocessor(self.cfg.dataset.name) | |||
self.train_data_collator = self.train_dataset.collator | |||
self.eval_data_collator = self.eval_dataset.collator | |||
device_name = kwargs.get('device', 'gpu') | |||
self.model.set_device(self.device, device_name) | |||
def train(self, *args, **kwargs): | |||
self.model.criterion.train() | |||
super().train(*args, **kwargs) | |||
def evaluate(self, checkpoint_path=None): | |||
if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
from modelscope.trainers.hooks import CheckpointHook | |||
CheckpointHook.load_checkpoint(checkpoint_path, self) | |||
self.model.eval() | |||
self._mode = ModeKeys.EVAL | |||
if self.eval_dataset is None: | |||
self.eval_dataloader = self.get_eval_data_loader() | |||
else: | |||
self.eval_dataloader = self._build_dataloader_with_dataset( | |||
self.eval_dataset, | |||
dist=self._dist, | |||
seed=self._seed, | |||
collate_fn=self.eval_data_collator, | |||
**self.cfg.evaluation.get('dataloader', {})) | |||
self.data_loader = self.eval_dataloader | |||
from modelscope.metrics import build_metric | |||
ann_file = self.eval_dataset.ann_file | |||
metric_classes = [] | |||
for metric in self.metrics: | |||
metric.update({'ann_file': ann_file}) | |||
metric_classes.append(build_metric(metric)) | |||
for m in metric_classes: | |||
m.trainer = self | |||
metric_values = self.evaluation_loop(self.eval_dataloader, | |||
metric_classes) | |||
self._metric_values = metric_values | |||
return metric_values | |||
def prediction_step(self, model, inputs): | |||
pass |
@@ -101,8 +101,9 @@ class CheckpointHook(Hook): | |||
model = trainer.model.module | |||
else: | |||
model = trainer.model | |||
meta = load_checkpoint(filename, model, trainer.optimizer, | |||
trainer.lr_scheduler) | |||
meta = load_checkpoint(filename, model, | |||
getattr(trainer, 'optimizer', None), | |||
getattr(trainer, 'lr_scheduler', None)) | |||
trainer._epoch = meta.get('epoch', trainer._epoch) | |||
trainer._iter = meta.get('iter', trainer._iter) | |||
trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) | |||
@@ -111,7 +112,7 @@ class CheckpointHook(Hook): | |||
# hook: Hook | |||
key = f'{hook.__class__}-{i}' | |||
if key in meta and hasattr(hook, 'load_state_dict'): | |||
hook.load_state_dict(meta[key]) | |||
hook.load_state_dict(meta.get(key, {})) | |||
else: | |||
trainer.logger.warn( | |||
f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' | |||
@@ -123,7 +124,7 @@ class CheckpointHook(Hook): | |||
f'The modelscope version of loaded checkpoint does not match the runtime version. ' | |||
f'The saved version: {version}, runtime version: {__version__}' | |||
) | |||
trainer.logger.warn( | |||
trainer.logger.info( | |||
f'Checkpoint {filename} saving time: {meta.get("time")}') | |||
return meta | |||
@@ -171,12 +172,17 @@ class CheckpointHook(Hook): | |||
else: | |||
model = trainer.model | |||
config = trainer.cfg.to_dict() | |||
# override the pipeline type with the task name after finetuning is done, | |||
# to avoid cases like a fill-mask pipeline being used for a text classification task | |||
config['pipeline'] = {'type': config['task']} | |||
if hasattr(model, 'save_pretrained'): | |||
model.save_pretrained( | |||
output_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE, | |||
save_function=save_checkpoint, | |||
config=trainer.cfg.to_dict(), | |||
config=config, | |||
with_meta=False) | |||
def after_train_iter(self, trainer): | |||
@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .clip import CLIPTrainer | |||
from .team import TEAMImgClsTrainer | |||
else: | |||
_import_structure = {'clip': ['CLIPTrainer']} | |||
_import_structure = { | |||
'clip': ['CLIPTrainer'], | |||
'team': ['TEAMImgClsTrainer'] | |||
} | |||
import sys | |||
@@ -0,0 +1,3 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .team_trainer import TEAMImgClsTrainer |
@@ -0,0 +1,144 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from collections import OrderedDict | |||
from typing import Callable, Dict, Optional | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torchvision.datasets as datasets | |||
import torchvision.transforms as transforms | |||
from sklearn.metrics import confusion_matrix | |||
from torch.optim import AdamW | |||
from torch.utils.data import DataLoader, Dataset | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.base import Model | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers.base import BaseTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.multi_modal.team.team_trainer_utils import ( | |||
get_optimizer, train_mapping, val_mapping) | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import DownloadMode, ModeKeys | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@TRAINERS.register_module(module_name=Trainers.image_classification_team) | |||
class TEAMImgClsTrainer(BaseTrainer): | |||
def __init__(self, cfg_file: str, model: str, device_id: int, | |||
data_collator: Callable, train_dataset: Dataset, | |||
val_dataset: Dataset, *args, **kwargs): | |||
super().__init__(cfg_file) | |||
self.cfg = Config.from_file(cfg_file) | |||
team_model = Model.from_pretrained(model) | |||
image_model = team_model.model.image_model.vision_transformer | |||
classification_model = nn.Sequential( | |||
OrderedDict([('encoder', image_model), | |||
('classifier', | |||
nn.Linear(768, self.cfg.dataset.class_num))])) | |||
self.model = classification_model | |||
for pname, param in self.model.named_parameters(): | |||
if 'encoder' in pname: | |||
param.requires_grad = False | |||
self.device_id = device_id | |||
self.total_epoch = self.cfg.train.epoch | |||
self.train_batch_size = self.cfg.train.batch_size | |||
self.val_batch_size = self.cfg.evaluation.batch_size | |||
self.ckpt_dir = self.cfg.train.ckpt_dir | |||
self.collate_fn = data_collator | |||
self.train_dataset = train_dataset | |||
self.val_dataset = val_dataset | |||
self.criterion = nn.CrossEntropyLoss().to(self.device_id) | |||
def train(self, *args, **kwargs): | |||
self.model.train() | |||
self.model.to(self.device_id) | |||
optimizer = get_optimizer(self.model) | |||
for epoch in range(self.total_epoch): | |||
train_params = { | |||
'pin_memory': True, | |||
'collate_fn': self.collate_fn, | |||
'batch_size': self.train_batch_size, | |||
'shuffle': True, | |||
'drop_last': True, | |||
'num_workers': 8 | |||
} | |||
train_loader = DataLoader(self.train_dataset, **train_params) | |||
for batch_idx, data in enumerate(train_loader): | |||
img_tensor, label_tensor = data['pixel_values'], data['labels'] | |||
img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
label_tensor = label_tensor.to( | |||
self.device_id, non_blocking=True) | |||
pred_logits = self.model(img_tensor) | |||
loss = self.criterion(pred_logits, label_tensor) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
if batch_idx % 10 == 0: | |||
logger.info( | |||
'epoch: {}, train batch {}/{}, loss={:.5f}'.format( | |||
epoch, batch_idx, len(train_loader), loss.item())) | |||
os.makedirs(self.ckpt_dir, exist_ok=True) | |||
torch.save(self.model.state_dict(), | |||
'{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) | |||
self.evaluate() | |||
def evaluate(self, | |||
checkpoint_path: Optional[str] = None, | |||
*args, | |||
**kwargs) -> Dict[str, float]: | |||
if checkpoint_path is not None: | |||
checkpoint_params = torch.load(checkpoint_path, 'cpu') | |||
self.model.load_state_dict(checkpoint_params) | |||
self.model.eval() | |||
self.model.to(self.device_id) | |||
val_params = { | |||
'collate_fn': self.collate_fn, | |||
'batch_size': self.val_batch_size, | |||
'shuffle': False, | |||
'drop_last': False, | |||
'num_workers': 8 | |||
} | |||
val_loader = DataLoader(self.val_dataset, **val_params) | |||
tp_cnt, processed_cnt = 0, 0 | |||
all_pred_labels, all_gt_labels = [], [] | |||
with torch.no_grad(): | |||
for batch_idx, data in enumerate(val_loader): | |||
img_tensor, label_tensor = data['pixel_values'], data['labels'] | |||
img_tensor = img_tensor.to(self.device_id, non_blocking=True) | |||
label_tensor = label_tensor.to( | |||
self.device_id, non_blocking=True) | |||
pred_logits = self.model(img_tensor) | |||
pred_labels = torch.max(pred_logits, dim=1)[1] | |||
tp_cnt += torch.sum(pred_labels == label_tensor).item() | |||
processed_cnt += img_tensor.shape[0] | |||
logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt)) | |||
all_pred_labels.extend(pred_labels.tolist()) | |||
all_gt_labels.extend(label_tensor.tolist()) | |||
conf_mat = confusion_matrix(all_gt_labels, all_pred_labels) | |||
acc_mean_per_class = np.mean(conf_mat.diagonal() | |||
/ conf_mat.sum(axis=1)) | |||
logger.info( | |||
'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class)) |
@@ -0,0 +1,87 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
import torchvision.transforms as transforms | |||
from PIL import Image | |||
from torch.optim import AdamW | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
train_transforms = transforms.Compose([ | |||
transforms.RandomResizedCrop(224), | |||
transforms.RandomHorizontalFlip(), | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
val_transforms = transforms.Compose([ | |||
transforms.Resize(256), | |||
transforms.CenterCrop(224), | |||
transforms.ToTensor(), | |||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
def train_mapping(examples): | |||
examples['pixel_values'] = [ | |||
train_transforms(Image.open(image).convert('RGB')) | |||
for image in examples['image:FILE'] | |||
] | |||
examples['labels'] = [label for label in examples['label:LABEL']] | |||
return examples | |||
def val_mapping(examples): | |||
examples['pixel_values'] = [ | |||
val_transforms(Image.open(image).convert('RGB')) | |||
for image in examples['image:FILE'] | |||
] | |||
examples['labels'] = [label for label in examples['label:LABEL']] | |||
return examples | |||
def collate_fn(examples): | |||
images = [] | |||
labels = [] | |||
for example in examples: | |||
images.append((example['pixel_values'])) | |||
labels.append(example['labels']) | |||
pixel_values = torch.stack(images) | |||
labels = torch.tensor(labels) | |||
return {'pixel_values': pixel_values, 'labels': labels} | |||
def get_params_groups(ddp_model, lr): | |||
large_lr_params = [] | |||
small_lr_params = [] | |||
for name, param in ddp_model.named_parameters(): | |||
if not param.requires_grad: | |||
continue | |||
if 'encoder' in name: | |||
small_lr_params.append(param) | |||
elif 'classifier' in name: | |||
large_lr_params.append(param) | |||
else: | |||
logger.info('skip param: {}'.format(name)) | |||
params_groups = [{ | |||
'params': small_lr_params, | |||
'lr': lr / 10.0 | |||
}, { | |||
'params': large_lr_params, | |||
'lr': lr | |||
}] | |||
return params_groups | |||
def get_optimizer(ddp_model): | |||
lr_init = 1e-3 | |||
betas = [0.9, 0.999] | |||
weight_decay = 0.02 | |||
params_groups = get_params_groups(ddp_model, lr=lr_init) | |||
return AdamW( | |||
params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) |
@@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||
break | |||
for metric_name in self.metrics: | |||
metric_values[metric_name] = np.average( | |||
[m[metric_name] for m in metric_values.values()]) | |||
all_metrics = [m[metric_name] for m in metric_values.values()] | |||
for key in all_metrics[0].keys(): | |||
metric_values[key] = np.average( | |||
[metric[key] for metric in all_metrics]) | |||
return metric_values |
@@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer): | |||
return dataset | |||
def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): | |||
return build_optimizer(self.model, cfg=cfg, default_args=default_args) | |||
try: | |||
return build_optimizer( | |||
self.model, cfg=cfg, default_args=default_args) | |||
except KeyError as e: | |||
self.logger.error( | |||
f'Failed to build the optimizer: {cfg} refers to a native torch optimizer, ' | |||
f'please check whether your torch version ({torch.__version__}) matches the config.' | |||
) | |||
raise e | |||
def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): | |||
return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
try: | |||
return build_lr_scheduler(cfg=cfg, default_args=default_args) | |||
except KeyError as e: | |||
self.logger.error( | |||
f'Failed to build the lr_scheduler: {cfg} refers to a native torch lr_scheduler, ' | |||
f'please check whether your torch version ({torch.__version__}) matches the config.' | |||
) | |||
raise e | |||
def create_optimizer_and_scheduler(self): | |||
""" Create optimizer and lr scheduler | |||
@@ -62,7 +62,10 @@ def single_gpu_test(trainer, | |||
if 'nsentences' in data: | |||
batch_size = data['nsentences'] | |||
else: | |||
batch_size = len(next(iter(data.values()))) | |||
try: | |||
batch_size = len(next(iter(data.values()))) | |||
except Exception: | |||
batch_size = data_loader.batch_size | |||
else: | |||
batch_size = len(data) | |||
for _ in range(batch_size): | |||
@@ -134,9 +134,7 @@ def load_checkpoint(filename, | |||
state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ | |||
'state_dict'] | |||
model.load_state_dict(state_dict) | |||
if 'meta' in checkpoint: | |||
return checkpoint.get('meta', {}) | |||
return checkpoint.get('meta', {}) | |||
def save_pretrained(model, | |||
@@ -238,6 +238,15 @@ class DownloadMode(enum.Enum): | |||
FORCE_REDOWNLOAD = 'force_redownload' | |||
class UploadMode(enum.Enum): | |||
""" How to upload object to remote. | |||
""" | |||
# Upload all local objects; existing remote objects may be overwritten. (Default) | |||
OVERWRITE = 'overwrite' | |||
# Upload local objects in append mode, skipping all existing remote objects. | |||
APPEND = 'append' | |||
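# A minimal usage sketch (illustrative; it assumes the DatasetUploadManager changes in this patch): | |||
# >>> manager = DatasetUploadManager(dataset_name='my_ds', namespace='my_ns', version='master') | |||
# >>> manager.upload(object_name='train.zip', local_file_path='/tmp/train.zip', | |||
# ...                upload_mode=UploadMode.APPEND)  # existing remote objects are skipped | |||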
class DatasetFormations(enum.Enum): | |||
""" How a dataset is organized and interpreted | |||
""" | |||
@@ -87,21 +87,23 @@ class HubOperationTest(unittest.TestCase): | |||
assert mdtime1 == mdtime2 | |||
def test_download_public_without_login(self): | |||
self.prepare_case() | |||
rmtree(ModelScopeConfig.path_credential) | |||
snapshot_path = snapshot_download( | |||
model_id=self.model_id, revision=self.revision) | |||
downloaded_file_path = os.path.join(snapshot_path, | |||
download_model_file_name) | |||
assert os.path.exists(downloaded_file_path) | |||
temporary_dir = tempfile.mkdtemp() | |||
downloaded_file = model_file_download( | |||
model_id=self.model_id, | |||
file_path=download_model_file_name, | |||
revision=self.revision, | |||
cache_dir=temporary_dir) | |||
assert os.path.exists(downloaded_file) | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
try: | |||
self.prepare_case() | |||
rmtree(ModelScopeConfig.path_credential) | |||
snapshot_path = snapshot_download( | |||
model_id=self.model_id, revision=self.revision) | |||
downloaded_file_path = os.path.join(snapshot_path, | |||
download_model_file_name) | |||
assert os.path.exists(downloaded_file_path) | |||
temporary_dir = tempfile.mkdtemp() | |||
downloaded_file = model_file_download( | |||
model_id=self.model_id, | |||
file_path=download_model_file_name, | |||
revision=self.revision, | |||
cache_dir=temporary_dir) | |||
assert os.path.exists(downloaded_file) | |||
finally: | |||
self.api.login(TEST_ACCESS_TOKEN1) | |||
def test_snapshot_delete_download_cache_file(self): | |||
self.prepare_case() | |||
@@ -0,0 +1,32 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
import numpy as np | |||
from modelscope.metrics.sequence_classification_metric import \ | |||
SequenceClassificationMetric | |||
from modelscope.utils.test_utils import test_level | |||
class TestTextClsMetrics(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_value(self): | |||
metric = SequenceClassificationMetric() | |||
outputs = { | |||
'logits': | |||
np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], | |||
[2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], | |||
[2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]) | |||
} | |||
inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])} | |||
metric.add(outputs, inputs) | |||
ret = metric.evaluate() | |||
self.assertTrue(np.isclose(ret['f1'], 0.5)) | |||
self.assertTrue(np.isclose(ret['accuracy'], 0.5)) | |||
print(ret) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,112 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
import zipfile | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.test_utils import test_level | |||
logger = logging.get_logger(__name__) | |||
KEY_EXTRACTED = 'extracted' | |||
EXPECTED_MSG = 'success' | |||
class DatasetDeleteTest(unittest.TestCase): | |||
def setUp(self): | |||
self.old_dir = os.getcwd() | |||
self.dataset_name = 'small_coco_for_test' | |||
self.dataset_file_name = self.dataset_name | |||
self.prepared_dataset_name = 'pets_small' | |||
self.token = os.getenv('TEST_UPLOAD_MS_TOKEN') | |||
error_msg = 'The modelscope token cannot be empty, please set the env variable: TEST_UPLOAD_MS_TOKEN' | |||
self.assertIsNotNone(self.token, msg=error_msg) | |||
from modelscope.hub.api import HubApi | |||
from modelscope.hub.api import ModelScopeConfig | |||
self.api = HubApi() | |||
self.api.login(self.token) | |||
# get user info | |||
self.namespace, _ = ModelScopeConfig.get_user_info() | |||
self.temp_dir = tempfile.mkdtemp() | |||
self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name) | |||
if not os.path.exists(self.test_work_dir): | |||
os.makedirs(self.test_work_dir) | |||
def tearDown(self): | |||
os.chdir(self.old_dir) | |||
shutil.rmtree(self.temp_dir, ignore_errors=True) | |||
logger.info( | |||
f'Temporary directory {self.temp_dir} successfully removed!') | |||
@staticmethod | |||
def get_raw_downloaded_file_path(extracted_path): | |||
raw_downloaded_file_path = '' | |||
raw_data_dir = os.path.abspath( | |||
os.path.join(extracted_path, '../../..')) | |||
for root, dirs, files in os.walk(raw_data_dir): | |||
if KEY_EXTRACTED in dirs: | |||
for file in files: | |||
curr_file_path = os.path.join(root, file) | |||
if zipfile.is_zipfile(curr_file_path): | |||
raw_downloaded_file_path = curr_file_path | |||
return raw_downloaded_file_path | |||
def upload_test_file(self): | |||
# Get the prepared data from hub, using default modelscope namespace | |||
ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') | |||
config_res = ms_ds_train._hf_ds.config_kwargs | |||
extracted_path = config_res.get('split_config').get('train') | |||
raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path) | |||
object_name = self.dataset_file_name + '_for_del.zip' | |||
MsDataset.upload( | |||
object_name=object_name, | |||
local_file_path=raw_zipfile_path, | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace) | |||
return object_name | |||
def upload_test_dir(self): | |||
ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') | |||
config_train = ms_ds_train._hf_ds.config_kwargs | |||
extracted_path_train = config_train.get('split_config').get('train') | |||
object_name = 'train_for_del' | |||
MsDataset.upload( | |||
object_name=object_name, | |||
local_file_path=os.path.join(extracted_path_train, | |||
'Pets/images/train'), | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace) | |||
return object_name + '/' | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_ds_delete_object(self): | |||
# upload prepared data | |||
file_name = self.upload_test_file() | |||
dir_name = self.upload_test_dir() | |||
# delete object | |||
del_file_msg = MsDataset.delete( | |||
object_name=file_name, | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace) | |||
del_dir_msg = MsDataset.delete( | |||
object_name=dir_name, | |||
dataset_name=self.dataset_name, | |||
namespace=self.namespace) | |||
assert all([del_file_msg == EXPECTED_MSG, del_dir_msg == EXPECTED_MSG]) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -243,6 +243,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def test_run_with_text_to_image_synthesis_with_name(self): | |||
model = 'damo/ofa_text-to-image-synthesis_coco_large_en' | |||
ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) | |||
ofa_pipe.model.generator.beam_size = 2 | |||
example = {'text': 'a bear in the water.'} | |||
result = ofa_pipe(example) | |||
result[OutputKeys.OUTPUT_IMG].save('result.png') | |||
@@ -253,6 +254,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
model = Model.from_pretrained( | |||
'damo/ofa_text-to-image-synthesis_coco_large_en') | |||
ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) | |||
ofa_pipe.model.generator.beam_size = 2 | |||
example = {'text': 'a bear in the water.'} | |||
result = ofa_pipe(example) | |||
result[OutputKeys.OUTPUT_IMG].save('result.png') | |||
@@ -14,7 +14,7 @@ class ReferringVideoObjectSegmentationTest(unittest.TestCase,
        self.task = Tasks.referring_video_object_segmentation
        self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    @unittest.skip('skip since the model is set to private for now')
    def test_referring_video_object_segmentation(self):
        input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4'
        text_queries = [
@@ -31,7 +31,7 @@ class ReferringVideoObjectSegmentationTest(unittest.TestCase,
        else:
            raise ValueError('process error')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skip('skip since the model is set to private for now')
    def test_referring_video_object_segmentation_with_default_task(self):
        input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4'
        text_queries = [
@@ -183,7 +183,7 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
            task=Tasks.text_generation, model='langboat/bloom-1b4-zh')
        print(pipe('中国的首都是'))

    @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt_neo(self):
        pipe = pipeline(
            task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base')
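The pattern in these hunks recurs throughout the test suite: a test that should run conditionally is gated with skipUnless on test_level(), while a test that cannot run at all (private model, missing checkpoint) gets an unconditional unittest.skip carrying the reason. A sketch of the two forms, using a hypothetical test class:

import unittest

from modelscope.utils.test_utils import test_level


class ExampleSkipPatternTest(unittest.TestCase):  # hypothetical class

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_runs_when_level_allows(self):
        pass  # executed only when the configured test level is >= 1

    @unittest.skip('skip since the model is set to private for now')
    def test_disabled_until_model_is_public(self):
        pass  # never executed until the decorator is removed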
@@ -20,16 +20,16 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
            Tasks.image_object_detection, model='damo/cv_tinynas_detection')
        result = tinynas_object_detection(
            'data/test/images/image_detection.jpg')
        print(result)
        print('airdet', result)

    @unittest.skip('will be enabled after damoyolo officially released')
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_damoyolo(self):
        tinynas_object_detection = pipeline(
            Tasks.image_object_detection,
            model='damo/cv_tinynas_object-detection_damoyolo')
        result = tinynas_object_detection(
            'data/test/images/image_detection.jpg')
        print(result)
        print('damoyolo', result)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
@@ -39,7 +39,8 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
    def test_image_object_detection_auto_pipeline(self):
        test_image = 'data/test/images/image_detection.jpg'
        tinynas_object_detection = pipeline(
            Tasks.image_object_detection, model='damo/cv_tinynas_detection')
            Tasks.image_object_detection,
            model='damo/cv_tinynas_object-detection_damoyolo')
        result = tinynas_object_detection(test_image)
        tinynas_object_detection.show_result(test_image, result,
                                             'demo_ret.jpg')
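The second hunk repoints the auto-pipeline test from cv_tinynas_detection to the damoyolo model id. The call sequence it exercises, reduced to a sketch (the image path is the same test asset used above):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

test_image = 'data/test/images/image_detection.jpg'
detector = pipeline(
    Tasks.image_object_detection,
    model='damo/cv_tinynas_object-detection_damoyolo')
result = detector(test_image)
# Draw the detected boxes on the input image and write the overlay to disk.
detector.show_result(test_image, result, 'demo_ret.jpg')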
@@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
        train_datasets = []
        from datasets import DownloadConfig
        dc = DownloadConfig()
        dc.local_files_only = True
        dc.local_files_only = False
        for lang in langs:
            train_datasets.append(
                load_dataset('xnli', lang, split='train', download_config=dc))
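Flipping local_files_only to False lets the xnli subsets be fetched from the hub instead of requiring a pre-populated local cache. In isolation the option works like this (the 'en' config is just an example language, not one taken from this PR):

from datasets import DownloadConfig, load_dataset

dc = DownloadConfig()
dc.local_files_only = False  # True would fail unless the dataset is already cached
train_ds = load_dataset('xnli', 'en', split='train', download_config=dc)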
@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.movie_scene_segmentation import \
    MovieSceneSegmentationModel
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestImageInstanceSegmentationTrainer(unittest.TestCase):

    model_id = 'damo/cv_swin-t_referring_video-object-segmentation'
    dataset_name = 'referring_vos_toydata'

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

        cache_path = snapshot_download(self.model_id)
        config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
        cfg = Config.from_file(config_path)
        max_epochs = cfg.train.max_epochs

        train_data_cfg = ConfigDict(
            name=self.dataset_name,
            split='train',
            test_mode=False,
            cfg=cfg.dataset)
        test_data_cfg = ConfigDict(
            name=self.dataset_name,
            split='test',
            test_mode=True,
            cfg=cfg.dataset)

        self.train_dataset = MsDataset.load(
            dataset_name=train_data_cfg.name,
            split=train_data_cfg.split,
            cfg=train_data_cfg.cfg,
            namespace='damo',
            test_mode=train_data_cfg.test_mode)
        assert next(
            iter(self.train_dataset.config_kwargs['split_config'].values()))

        self.test_dataset = MsDataset.load(
            dataset_name=test_data_cfg.name,
            split=test_data_cfg.split,
            cfg=test_data_cfg.cfg,
            namespace='damo',
            test_mode=test_data_cfg.test_mode)
        assert next(
            iter(self.test_dataset.config_kwargs['split_config'].values()))

        self.max_epochs = max_epochs

    @unittest.skip('skip since the model is set to private for now')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            work_dir='./work_dir')

        trainer = build_trainer(
            name=Trainers.referring_video_object_segmentation,
            default_args=kwargs)
        trainer.train()
        results_files = os.listdir(trainer.work_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)

    @unittest.skip('skip since the model is set to private for now')
    def test_trainer_with_model_and_args(self):
        cache_path = snapshot_download(self.model_id)
        model = MovieSceneSegmentationModel.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            work_dir='./work_dir')

        trainer = build_trainer(
            name=Trainers.referring_video_object_segmentation,
            default_args=kwargs)
        trainer.train()
        results_files = os.listdir(trainer.work_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)


if __name__ == '__main__':
    unittest.main()
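Both trainer tests in this new file are skipped while the model is private, but the call chain they cover is short. A minimal sketch, assuming the train/eval MsDataset splits have already been prepared as in setUp():

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer


def run_referring_vos_training(train_ds, eval_ds, work_dir='./work_dir'):
    # train_ds / eval_ds are the MsDataset splits loaded in setUp() above.
    kwargs = dict(
        model='damo/cv_swin-t_referring_video-object-segmentation',
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        work_dir=work_dir)
    trainer = build_trainer(
        name=Trainers.referring_video_object_segmentation,
        default_args=kwargs)
    trainer.train()
    # The test then asserts that a <timestamp>.log.json file appears in work_dir.
    return trainer.work_dir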
@@ -0,0 +1,94 @@
import os
import unittest

import json
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
    collate_fn, train_mapping, val_mapping)
from modelscope.utils.config import Config
from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


def train_worker(device_id):
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    ckpt_dir = './ckpt'
    os.makedirs(ckpt_dir, exist_ok=True)

    # Use epoch=1 for faster training here
    cfg = Config({
        'framework': 'pytorch',
        'task': 'multi-modal-similarity',
        'pipeline': {
            'type': 'multi-modal-similarity'
        },
        'model': {
            'type': 'team-multi-modal-similarity'
        },
        'dataset': {
            'name': 'Caltech101',
            'class_num': 101
        },
        'preprocessor': {},
        'train': {
            'epoch': 1,
            'batch_size': 32,
            'ckpt_dir': ckpt_dir
        },
        'evaluation': {
            'batch_size': 64
        }
    })
    cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
    cfg.dump(cfg_file)

    train_dataset = MsDataset.load(
        cfg.dataset.name,
        namespace='modelscope',
        split='train',
        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
    train_dataset = train_dataset.with_transform(train_mapping)
    val_dataset = MsDataset.load(
        cfg.dataset.name,
        namespace='modelscope',
        split='validation',
        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
    val_dataset = val_dataset.with_transform(val_mapping)

    default_args = dict(
        cfg_file=cfg_file,
        model=model_id,
        device_id=device_id,
        data_collator=collate_fn,
        train_dataset=train_dataset,
        val_dataset=val_dataset)
    trainer = build_trainer(
        name=Trainers.image_classification_team, default_args=default_args)
    trainer.train()
    trainer.evaluate()


class TEAMTransferTrainerTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        if torch.cuda.device_count() > 0:
            train_worker(device_id=0)
        else:
            train_worker(device_id=-1)

        logger.info('Training done')


if __name__ == '__main__':
    unittest.main()
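Two details of this new test are worth calling out: the training config is assembled in code and dumped to configuration.json so build_trainer can consume it via cfg_file, and device_id selects GPU 0 when CUDA is available, falling back to -1 (CPU) otherwise. A trimmed-down sketch of just those two steps (only the train section is shown; the full config above also carries framework/task/model/dataset/evaluation entries):

import os

import torch

from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

ckpt_dir = './ckpt'
os.makedirs(ckpt_dir, exist_ok=True)

cfg = Config({'train': {'epoch': 1, 'batch_size': 32, 'ckpt_dir': ckpt_dir}})
cfg_file = os.path.join(ckpt_dir, ModelFile.CONFIGURATION)
cfg.dump(cfg_file)  # serialized so build_trainer can read it via cfg_file=

device_id = 0 if torch.cuda.device_count() > 0 else -1  # -1 means run on CPU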
@@ -119,7 +119,7 @@ class TestTrainerWithNlp(unittest.TestCase):
            checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
        self.assertTrue(Metrics.accuracy in eval_results)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skip('skip for now before test is re-configured')
    def test_trainer_with_configured_datasets(self):
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        cfg: Config = read_config(model_id)
@@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase):
                trainer, 'trainer_continue_train', level='strict'):
            trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth'))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_evaluation(self):
        tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(tmp_dir):
            os.makedirs(tmp_dir)

        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
        cache_path = snapshot_download(model_id)
        model = SbertForSequenceClassification.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
            model=model,
            eval_dataset=self.dataset,
            work_dir=self.tmp_dir)

        trainer = build_trainer(default_args=kwargs)
        print(trainer.evaluate(cache_path + '/pytorch_model.bin'))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(tmp_dir):
            os.makedirs(tmp_dir)

        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
        cache_path = snapshot_download(model_id)
        model = SbertForSequenceClassification.from_pretrained(cache_path)
        kwargs = dict(