Browse Source

fix: torch.concat compatibility with torch1.8

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10885659
master^2
wenmeng.zwm 3 years ago
parent
commit
3b78421236
5 changed files with 9 additions and 9 deletions
  1. +2
    -2
      modelscope/models/cv/body_3d_keypoints/body_3d_pose.py
  2. +2
    -2
      modelscope/models/science/unifold/data/data_ops.py
  3. +1
    -1
      modelscope/models/science/unifold/modules/attentions.py
  4. +1
    -1
      modelscope/pipelines/base.py
  5. +3
    -3
      modelscope/pipelines/multi_modal/image_captioning_pipeline.py

+ 2
- 2
modelscope/models/cv/body_3d_keypoints/body_3d_pose.py View File

@@ -224,8 +224,8 @@ class BodyKeypointsDetection3D(TorchModel):
lst_pose2d_cannoical.append(pose2d_canonical[:,
i - pad:i + pad + 1])

input_pose2d_rr = torch.concat(lst_pose2d_cannoical, axis=0)
input_pose2d_cannoical = torch.concat(lst_pose2d_cannoical, axis=0)
input_pose2d_rr = torch.cat(lst_pose2d_cannoical, axis=0)
input_pose2d_cannoical = torch.cat(lst_pose2d_cannoical, axis=0)

if self.cfg.model.MODEL.USE_CANONICAL_COORDS:
input_pose2d_abs = input_pose2d_cannoical.clone()


+ 2
- 2
modelscope/models/science/unifold/data/data_ops.py View File

@@ -730,7 +730,7 @@ def make_msa_feat_v2(batch):
batch['cluster_profile'],
deletion_mean_value,
]
batch['msa_feat'] = torch.concat(msa_feat, dim=-1)
batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
return batch


@@ -1320,7 +1320,7 @@ def get_contiguous_crop_idx(
asym_offset + this_start + csz))
asym_offset += ll

return torch.concat(crop_idxs)
return torch.cat(crop_idxs)


def get_spatial_crop_idx(


+ 1
- 1
modelscope/models/science/unifold/modules/attentions.py View File

@@ -217,7 +217,7 @@ class MSAAttention(nn.Module):
if mask is not None else None)
outputs.append(
self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias))
return torch.concat(outputs, dim=-3)
return torch.cat(outputs, dim=-3)

def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
m = self.layer_norm_m(m)


+ 1
- 1
modelscope/pipelines/base.py View File

@@ -233,7 +233,7 @@ class Pipeline(ABC):
batch_data[k] = value_list
for k in batch_data.keys():
if isinstance(batch_data[k][0], torch.Tensor):
batch_data[k] = torch.concat(batch_data[k])
batch_data[k] = torch.cat(batch_data[k])
return batch_data

def _process_batch(self, input: List[Input], batch_size,


+ 3
- 3
modelscope/pipelines/multi_modal/image_captioning_pipeline.py View File

@@ -46,17 +46,17 @@ class ImageCaptioningPipeline(Pipeline):
batch_data['samples'] = [d['samples'][0] for d in data]
batch_data['net_input'] = {}
for k in data[0]['net_input'].keys():
batch_data['net_input'][k] = torch.concat(
batch_data['net_input'][k] = torch.cat(
[d['net_input'][k] for d in data])

return batch_data
elif isinstance(self.model, MPlugForAllTasks):
from transformers.tokenization_utils_base import BatchEncoding
batch_data = dict(train=data[0]['train'])
batch_data['image'] = torch.concat([d['image'] for d in data])
batch_data['image'] = torch.cat([d['image'] for d in data])
question = {}
for k in data[0]['question'].keys():
question[k] = torch.concat([d['question'][k] for d in data])
question[k] = torch.cat([d['question'][k] for d in data])
batch_data['question'] = BatchEncoding(question)
return batch_data
else:


Loading…
Cancel
Save