Browse Source

Merge branch 'master' into nlp/space/dst

master
ly119399 3 years ago
parent
commit
9e1c7d54d5
2 changed files with 14 additions and 8 deletions
  1. +13
    -7
      modelscope/models/multi_modal/image_captioning_model.py
  2. +1
    -1
      tests/pipelines/test_animal_recognation.py

+ 13
- 7
modelscope/models/multi_modal/image_captioning_model.py View File

@@ -1,6 +1,7 @@
import os.path as osp
from typing import Any, Dict

import torch.cuda
from PIL import Image

from modelscope.metainfo import Models
@@ -26,9 +27,13 @@ class OfaForImageCaptioning(Model):
self.eval_caption = eval_caption

tasks.register_task('caption', CaptionTask)
use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
use_fp16 = kwargs[
'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.use_fp16 = kwargs[
'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\
else False
overrides = {
'bpe_dir': bpe_dir,
'eval_cider': False,
@@ -39,13 +44,11 @@ class OfaForImageCaptioning(Model):
}
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(local_model), arg_overrides=overrides)

# Move models to GPU
for model in models:
model.eval()
if use_cuda:
model.cuda()
if use_fp16:
model.to(self._device)
if self.use_fp16:
model.half()
model.prepare_for_inference_(cfg)
self.models = models
@@ -68,6 +71,9 @@ class OfaForImageCaptioning(Model):
self.task = task

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
import fairseq.utils
if torch.cuda.is_available():
input = fairseq.utils.move_to_cuda(input, device=self._device)
results, _ = self.eval_caption(self.task, self.generator, self.models,
input)
return {


+ 1
- 1
tests/pipelines/test_animal_recognation.py View File

@@ -11,7 +11,7 @@ class MultiModalFeatureTest(unittest.TestCase):
def test_run(self):
animal_recog = pipeline(
Tasks.image_classification,
model='damo/cv_resnest101_animal_recognation')
model='damo/cv_resnest101_animal_recognition')
result = animal_recog('data/test/images/image1.jpg')
print(result)



Loading…
Cancel
Save