@@ -160,6 +160,7 @@ class Pipeline(ABC):
# input_dict = self._handle_input(input)
# sanitize the parameters
batch_size = kwargs.pop('batch_size', None)
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
**kwargs)
kwargs['preprocess_params'] = preprocess_params
@@ -167,9 +168,12 @@ class Pipeline(ABC):
kwargs['postprocess_params'] = postprocess_params
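        # A minimal sketch (hypothetical keys), illustrating the contract of
        # _sanitize_parameters used above: a subclass splits its own kwargs into
        # the three dicts consumed by preprocess/forward/postprocess, e.g.
        #
        #     def _sanitize_parameters(self, **kwargs):
        #         preprocess_params = {'max_length': kwargs.pop('max_length', 128)}
        #         forward_params = {}
        #         postprocess_params = {'topk': kwargs.pop('topk', 5)}
        #         return preprocess_params, forward_params, postprocess_params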
if isinstance(input, list):
if batch_size is None:
output = []
for ele in input:
output.append(self._process_single(ele, *args, **kwargs))
else:
output = self._process_batch(input, batch_size, **kwargs)
elif isinstance(input, MsDataset):
return self._process_iterator(input, *args, **kwargs)
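        # Usage sketch (hypothetical task and placeholder model id), assuming the
        # standard `pipeline` factory: a list input plus `batch_size` is routed
        # through _process_batch below instead of the per-sample loop.
        #
        #     from modelscope.pipelines import pipeline
        #     pipe = pipeline('text-classification', model='<model-id>')
        #     results = pipe(['first sentence', 'second sentence', 'third sentence'],
        #                    batch_size=2)
        #     # results is a list with one output dict per input element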
@@ -204,6 +208,7 @@ class Pipeline(ABC):
postprocess_params = kwargs.get('postprocess_params', {})
self._check_input(input)
out = self.preprocess(input, **preprocess_params)
with device_placement(self.framework, self.device_name):
if self.framework == Frameworks.torch:
with torch.no_grad():
@@ -217,6 +222,55 @@ class Pipeline(ABC):
self._check_output(out)
return out
def _batch(self, data_list):
batch_data = {}
for sample_preprocessed in data_list:
for k, v in sample_preprocessed.items():
value_list = batch_data.get(k, [])
value_list.append(v)
batch_data[k] = value_list
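        # concatenate per-sample tensors along the batch dimension;
        # non-tensor values (e.g. lists of strings) stay as python lists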
for k in batch_data.keys():
if isinstance(batch_data[k][0], torch.Tensor):
batch_data[k] = torch.concat(batch_data[k])
return batch_data
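    # A rough illustration of _batch, assuming each preprocessed sample already
    # carries a leading batch dimension of 1 (keys are hypothetical):
    #
    #     data_list = [{'input_ids': torch.ones(1, 8), 'text': 'a'},
    #                  {'input_ids': torch.zeros(1, 8), 'text': 'b'}]
    #     self._batch(data_list)
    #     # -> {'input_ids': tensor of shape (2, 8), 'text': ['a', 'b']}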
    def _process_batch(self, input: List[Input], batch_size,
                       **kwargs) -> List[Dict[str, Any]]:
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')
        # process the input list in mini-batches of size batch_size
output_list = []
for i in range(0, len(input), batch_size):
end = min(i + batch_size, len(input))
real_batch_size = end - i
            # preprocess every raw sample in the current slice
            preprocessed_list = [
                self.preprocess(sample, **preprocess_params)
                for sample in input[i:end]
            ]
with device_placement(self.framework, self.device_name):
                if self.framework == Frameworks.torch:
                    with torch.no_grad():
                        # merge the preprocessed samples into a single batch
                        batched_out = self._batch(preprocessed_list)
                        if self._auto_collate:
                            batched_out = self._collate_fn(batched_out)
                        batched_out = self.forward(batched_out,
                                                   **forward_params)
                else:
                    batched_out = self._batch(preprocessed_list)
                    batched_out = self.forward(batched_out, **forward_params)
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
out[k] = element[batch_idx]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)
return output_list
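    # For example, with len(input) == 10 and batch_size == 4 the loop above runs
    # three times with real_batch_size 4, 4 and 2, and batched_out[k][batch_idx]
    # slices the per-sample result back out of each batched output, so
    # output_list ends up with exactly one entry per input element.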
def _check_input(self, input):
task_name = self.group_key
if task_name in TASK_INPUTS:
@@ -290,12 +344,14 @@ class Pipeline(ABC):
return self.model(inputs, **forward_params)
@abstractmethod
def postprocess(self, inputs: Dict[str, Any],
**post_params) -> Dict[str, Any]:
""" If current pipeline support model reuse, common postprocess
code should be write here.
Args:
inputs: input data
post_params: post process parameters
Return:
dict of results: a dict containing outputs of model, each
@@ -429,7 +485,11 @@ def collate_fn(data, device):
from torch.utils.data.dataloader import default_collate
from modelscope.preprocessors.nlp import InputFeatures
if isinstance(data, dict) or isinstance(data, Mapping):
# add compatibility for img_metas for mmlab models
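        # img_metas in mmlab-style pipelines holds plain python metadata
        # (image shape, scale factor, etc.), so it is passed through unchanged
        # instead of being recursively collated onto the device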
return type(data)({
k: collate_fn(v, device) if k != 'img_metas' else v
for k, v in data.items()
})
elif isinstance(data, (tuple, list)):
if 0 == len(data):
return torch.Tensor([])