|
|
|
@@ -46,17 +46,17 @@ class ImageCaptioningPipeline(Pipeline): |
|
|
|
batch_data['samples'] = [d['samples'][0] for d in data] |
|
|
|
batch_data['net_input'] = {} |
|
|
|
for k in data[0]['net_input'].keys(): |
|
|
|
batch_data['net_input'][k] = torch.concat( |
|
|
|
batch_data['net_input'][k] = torch.cat( |
|
|
|
[d['net_input'][k] for d in data]) |
|
|
|
|
|
|
|
return batch_data |
|
|
|
elif isinstance(self.model, MPlugForAllTasks): |
|
|
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
|
batch_data = dict(train=data[0]['train']) |
|
|
|
batch_data['image'] = torch.concat([d['image'] for d in data]) |
|
|
|
batch_data['image'] = torch.cat([d['image'] for d in data]) |
|
|
|
question = {} |
|
|
|
for k in data[0]['question'].keys(): |
|
|
|
question[k] = torch.concat([d['question'][k] for d in data]) |
|
|
|
question[k] = torch.cat([d['question'][k] for d in data]) |
|
|
|
batch_data['question'] = BatchEncoding(question) |
|
|
|
return batch_data |
|
|
|
else: |
|
|
|
|