Browse Source

fix device mis match

master
行嗔 3 years ago
parent
commit
7ccf40b625
2 changed files with 3 additions and 15 deletions
  1. +3
    -2
      modelscope/models/multi_modal/ofa_for_all_tasks.py
  2. +0
    -13
      tests/pipelines/test_ofa_tasks.py

+ 3
- 2
modelscope/models/multi_modal/ofa_for_all_tasks.py View File

@@ -187,13 +187,14 @@ class OfaForAllTasks(TorchModel):
valid_size = len(val_ans)
valid_tgt_items = [
torch.cat([
torch.tensor(decoder_prompt[1:]), valid_answer,
torch.tensor(decoder_prompt[1:]).to('cpu'), valid_answer,
self.eos_item
]) for decoder_prompt in input['decoder_prompts']
for valid_answer in val_ans
]
valid_prev_items = [
torch.cat([torch.tensor(decoder_prompt), valid_answer])
torch.cat(
[torch.tensor(decoder_prompt).to('cpu'), valid_answer])
for decoder_prompt in input['decoder_prompts']
for valid_answer in val_ans
]


+ 0
- 13
tests/pipelines/test_ofa_tasks.py View File

@@ -37,19 +37,6 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
result = img_captioning({'image': image})
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_zh_with_model(self):
model = Model.from_pretrained(
'/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_image-caption_coco_base_zh'
)
img_captioning = pipeline(
task=Tasks.image_captioning,
model=model,
)
image = 'data/test/images/image_captioning.png'
result = img_captioning({'image': image})
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_image_captioning_with_name(self):
img_captioning = pipeline(


Loading…
Cancel
Save