
[to 43878347] support batch inference in pipeline

It is recommended that each pipeline implement `_batch` to turn a list of preprocessed samples into a single batched data dict.

Then, by passing `batch_size=n`, batch inference can be used in the pipeline, for example:
```python
img_captioning = pipeline(
            Tasks.image_captioning,
            model='damo/ofa_image-caption_coco_large_en')

results = img_captioning(
            [{
                'image': 'data/test/images/image_captioning.png'
            } for _ in range(6)],
            batch_size=2)
```
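
The base `Pipeline._batch` added in this change already covers the common case where each preprocessed sample is a flat dict of tensors; only pipelines with nested sample structures need their own override. A minimal sketch of such an override (hypothetical `MyPipeline`; the other required pipeline methods are omitted), mirroring the default implementation in `base.py` below:

```python
import torch

from modelscope.pipelines.base import Pipeline


class MyPipeline(Pipeline):

    def _batch(self, data_list):
        # Gather per-sample values by key, then stack tensors along the
        # batch dimension so forward() sees one dict of [batch, ...] tensors.
        batch_data = {}
        for sample in data_list:
            for k, v in sample.items():
                batch_data.setdefault(k, []).append(v)
        for k, values in batch_data.items():
            if isinstance(values[0], torch.Tensor):
                batch_data[k] = torch.cat(values)
        return batch_data
```

Note that `torch.cat` assumes every sample's tensors share the same shape along non-batch dimensions; variable-length inputs need to be padded during preprocessing.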

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10051193
wenmeng.zwm committed 2 years ago · commit 7c0d7f872c
3 changed files with 103 additions and 5 deletions:

1. modelscope/pipelines/base.py (+65 −5)
2. modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+25 −0)
3. tests/pipelines/test_ofa_tasks.py (+13 −0)

modelscope/pipelines/base.py (+65 −5)

```diff
@@ -160,6 +160,7 @@ class Pipeline(ABC):
         # input_dict = self._handle_input(input)
 
         # sanitize the parameters
+        batch_size = kwargs.pop('batch_size', None)
         preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
             **kwargs)
         kwargs['preprocess_params'] = preprocess_params
@@ -167,9 +168,12 @@
         kwargs['postprocess_params'] = postprocess_params
 
         if isinstance(input, list):
-            output = []
-            for ele in input:
-                output.append(self._process_single(ele, *args, **kwargs))
+            if batch_size is None:
+                output = []
+                for ele in input:
+                    output.append(self._process_single(ele, *args, **kwargs))
+            else:
+                output = self._process_batch(input, batch_size, **kwargs)
 
         elif isinstance(input, MsDataset):
             return self._process_iterator(input, *args, **kwargs)
@@ -204,6 +208,7 @@
         postprocess_params = kwargs.get('postprocess_params', {})
         self._check_input(input)
         out = self.preprocess(input, **preprocess_params)
+
         with device_placement(self.framework, self.device_name):
             if self.framework == Frameworks.torch:
                 with torch.no_grad():
@@ -217,6 +222,57 @@
         self._check_output(out)
         return out
 
+    def _batch(self, data_list):
+        batch_data = {}
+        for sample_preprocessed in data_list:
+            for k, v in sample_preprocessed.items():
+                value_list = batch_data.get(k, [])
+                value_list.append(v)
+                batch_data[k] = value_list
+        for k in batch_data.keys():
+            if isinstance(batch_data[k][0], torch.Tensor):
+                batch_data[k] = torch.cat(batch_data[k])
+        return batch_data
+
+    def _process_batch(self, input: List[Input], batch_size,
+                       **kwargs) -> Dict[str, Any]:
+        preprocess_params = kwargs.get('preprocess_params')
+        forward_params = kwargs.get('forward_params')
+        postprocess_params = kwargs.get('postprocess_params')
+
+        output_list = []
+        for i in range(0, len(input), batch_size):
+            end = min(i + batch_size, len(input))
+            real_batch_size = end - i
+            preprocessed_list = [
+                self.preprocess(sample, **preprocess_params)
+                for sample in input[i:end]
+            ]
+
+            with device_placement(self.framework, self.device_name):
+                # batch the preprocessed samples into a single data dict
+                batched_input = self._batch(preprocessed_list)
+                if self.framework == Frameworks.torch:
+                    with torch.no_grad():
+                        if self._auto_collate:
+                            batched_input = self._collate_fn(batched_input)
+                        batched_out = self.forward(batched_input,
+                                                   **forward_params)
+                else:
+                    batched_out = self.forward(batched_input, **forward_params)
+
+            # split the batched output back into single-sample outputs
+            for batch_idx in range(real_batch_size):
+                out = {}
+                for k, element in batched_out.items():
+                    if element is not None:
+                        out[k] = element[batch_idx]
+                out = self.postprocess(out, **postprocess_params)
+                self._check_output(out)
+                output_list.append(out)
+
+        return output_list
+
     def _check_input(self, input):
         task_name = self.group_key
         if task_name in TASK_INPUTS:
@@ -290,12 +344,14 @@
         return self.model(inputs, **forward_params)
 
     @abstractmethod
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def postprocess(self, inputs: Dict[str, Any],
+                    **post_params) -> Dict[str, Any]:
         """ If the current pipeline supports model reuse, common postprocess
         code should be written here.
 
         Args:
             inputs: input data
+            post_params: postprocess parameters
 
         Return:
             dict of results: a dict containing outputs of model, each
@@ -429,7 +485,11 @@ def collate_fn(data, device):
     from torch.utils.data.dataloader import default_collate
     from modelscope.preprocessors.nlp import InputFeatures
     if isinstance(data, dict) or isinstance(data, Mapping):
-        return type(data)({k: collate_fn(v, device) for k, v in data.items()})
+        # add compatibility for img_metas for mmlab models
+        return type(data)({
+            k: collate_fn(v, device) if k != 'img_metas' else v
+            for k, v in data.items()
+        })
     elif isinstance(data, (tuple, list)):
         if 0 == len(data):
             return torch.Tensor([])
```
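
To illustrate the chunking in `_process_batch` above: start indices step by `batch_size`, the final slice is clipped to `len(input)`, and so the last batch may be smaller. A standalone sketch of the index arithmetic:

```python
def chunk_ranges(n: int, batch_size: int):
    # mirrors the slicing in _process_batch: [i, min(i + batch_size, n))
    return [(i, min(i + batch_size, n)) for i in range(0, n, batch_size)]


print(chunk_ranges(6, 2))  # [(0, 2), (2, 4), (4, 6)]
print(chunk_ranges(7, 2))  # [(0, 2), (2, 4), (4, 6), (6, 7)] -- last batch holds 1 sample
```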


modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+25 −0)

```diff
@@ -46,6 +46,29 @@ class ImageCaptioningPipeline(Pipeline):
         preprocessor = MPlugPreprocessor(pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
 
+    def _batch(self, data):
+        if isinstance(self.model, OfaForAllTasks):
+            # collate batch data manually due to the nested data structure
+            batch_data = {}
+            batch_data['nsentences'] = len(data)
+            batch_data['samples'] = [d['samples'][0] for d in data]
+            batch_data['net_input'] = {}
+            for k in data[0]['net_input'].keys():
+                batch_data['net_input'][k] = torch.cat(
+                    [d['net_input'][k] for d in data])
+            return batch_data
+        elif isinstance(self.model, MPlugForAllTasks):
+            from transformers.tokenization_utils_base import BatchEncoding
+            batch_data = dict(train=data[0]['train'])
+            batch_data['image'] = torch.cat([d['image'] for d in data])
+            question = {}
+            for k in data[0]['question'].keys():
+                question[k] = torch.cat([d['question'][k] for d in data])
+            batch_data['question'] = BatchEncoding(question)
+            return batch_data
+        else:
+            return super()._batch(data)
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         with torch.no_grad():
```
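
For reference, the `OfaForAllTasks` branch assumes each preprocessed sample carries the nested `nsentences` / `samples` / `net_input` layout collated above. A toy example of what that collation produces (hypothetical keys and shapes, not the real OFA feature set):

```python
import torch

# two toy OFA-style preprocessed samples (hypothetical keys and shapes)
data = [{
    'nsentences': 1,
    'samples': [{'id': i}],
    'net_input': {
        'input_ids': torch.zeros(1, 8, dtype=torch.long)
    }
} for i in range(2)]

# what the OfaForAllTasks branch of _batch builds
batch_data = {
    'nsentences': len(data),
    'samples': [d['samples'][0] for d in data],
    'net_input': {
        k: torch.cat([d['net_input'][k] for d in data])
        for k in data[0]['net_input']
    }
}
print(batch_data['net_input']['input_ids'].shape)  # torch.Size([2, 8])
```

As with the base `_batch`, concatenation assumes equal tensor shapes per key across the batch; ragged sequences would need padding in the preprocessor.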


tests/pipelines/test_ofa_tasks.py (+13 −0)

```diff
@@ -45,6 +45,19 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_image_captioning_batch(self):
+        img_captioning = pipeline(
+            Tasks.image_captioning,
+            model='damo/ofa_image-caption_coco_large_en')
+        results = img_captioning(
+            [{
+                'image': 'data/test/images/image_captioning.png'
+            } for _ in range(6)],
+            batch_size=2)
+        for r in results:
+            print(r[OutputKeys.CAPTION])
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_ocr_recognize_with_name(self):
         ocr_recognize = pipeline(
```

