Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10885659master^2
| @@ -224,8 +224,8 @@ class BodyKeypointsDetection3D(TorchModel): | |||||
| lst_pose2d_cannoical.append(pose2d_canonical[:, | lst_pose2d_cannoical.append(pose2d_canonical[:, | ||||
| i - pad:i + pad + 1]) | i - pad:i + pad + 1]) | ||||
| input_pose2d_rr = torch.concat(lst_pose2d_cannoical, axis=0) | |||||
| input_pose2d_cannoical = torch.concat(lst_pose2d_cannoical, axis=0) | |||||
| input_pose2d_rr = torch.cat(lst_pose2d_cannoical, axis=0) | |||||
| input_pose2d_cannoical = torch.cat(lst_pose2d_cannoical, axis=0) | |||||
| if self.cfg.model.MODEL.USE_CANONICAL_COORDS: | if self.cfg.model.MODEL.USE_CANONICAL_COORDS: | ||||
| input_pose2d_abs = input_pose2d_cannoical.clone() | input_pose2d_abs = input_pose2d_cannoical.clone() | ||||
| @@ -730,7 +730,7 @@ def make_msa_feat_v2(batch): | |||||
| batch['cluster_profile'], | batch['cluster_profile'], | ||||
| deletion_mean_value, | deletion_mean_value, | ||||
| ] | ] | ||||
| batch['msa_feat'] = torch.concat(msa_feat, dim=-1) | |||||
| batch['msa_feat'] = torch.cat(msa_feat, dim=-1) | |||||
| return batch | return batch | ||||
| @@ -1320,7 +1320,7 @@ def get_contiguous_crop_idx( | |||||
| asym_offset + this_start + csz)) | asym_offset + this_start + csz)) | ||||
| asym_offset += ll | asym_offset += ll | ||||
| return torch.concat(crop_idxs) | |||||
| return torch.cat(crop_idxs) | |||||
| def get_spatial_crop_idx( | def get_spatial_crop_idx( | ||||
| @@ -217,7 +217,7 @@ class MSAAttention(nn.Module): | |||||
| if mask is not None else None) | if mask is not None else None) | ||||
| outputs.append( | outputs.append( | ||||
| self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) | self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) | ||||
| return torch.concat(outputs, dim=-3) | |||||
| return torch.cat(outputs, dim=-3) | |||||
| def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): | def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): | ||||
| m = self.layer_norm_m(m) | m = self.layer_norm_m(m) | ||||
| @@ -233,7 +233,7 @@ class Pipeline(ABC): | |||||
| batch_data[k] = value_list | batch_data[k] = value_list | ||||
| for k in batch_data.keys(): | for k in batch_data.keys(): | ||||
| if isinstance(batch_data[k][0], torch.Tensor): | if isinstance(batch_data[k][0], torch.Tensor): | ||||
| batch_data[k] = torch.concat(batch_data[k]) | |||||
| batch_data[k] = torch.cat(batch_data[k]) | |||||
| return batch_data | return batch_data | ||||
| def _process_batch(self, input: List[Input], batch_size, | def _process_batch(self, input: List[Input], batch_size, | ||||
| @@ -46,17 +46,17 @@ class ImageCaptioningPipeline(Pipeline): | |||||
| batch_data['samples'] = [d['samples'][0] for d in data] | batch_data['samples'] = [d['samples'][0] for d in data] | ||||
| batch_data['net_input'] = {} | batch_data['net_input'] = {} | ||||
| for k in data[0]['net_input'].keys(): | for k in data[0]['net_input'].keys(): | ||||
| batch_data['net_input'][k] = torch.concat( | |||||
| batch_data['net_input'][k] = torch.cat( | |||||
| [d['net_input'][k] for d in data]) | [d['net_input'][k] for d in data]) | ||||
| return batch_data | return batch_data | ||||
| elif isinstance(self.model, MPlugForAllTasks): | elif isinstance(self.model, MPlugForAllTasks): | ||||
| from transformers.tokenization_utils_base import BatchEncoding | from transformers.tokenization_utils_base import BatchEncoding | ||||
| batch_data = dict(train=data[0]['train']) | batch_data = dict(train=data[0]['train']) | ||||
| batch_data['image'] = torch.concat([d['image'] for d in data]) | |||||
| batch_data['image'] = torch.cat([d['image'] for d in data]) | |||||
| question = {} | question = {} | ||||
| for k in data[0]['question'].keys(): | for k in data[0]['question'].keys(): | ||||
| question[k] = torch.concat([d['question'][k] for d in data]) | |||||
| question[k] = torch.cat([d['question'][k] for d in data]) | |||||
| batch_data['question'] = BatchEncoding(question) | batch_data['question'] = BatchEncoding(question) | ||||
| return batch_data | return batch_data | ||||
| else: | else: | ||||