
fix: torch.concat compatibility with torch1.8

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10885659
wenmeng.zwm committed 3 years ago
commit 3b78421236
5 changed files with 9 additions and 9 deletions

  1. modelscope/models/cv/body_3d_keypoints/body_3d_pose.py (+2 -2)
  2. modelscope/models/science/unifold/data/data_ops.py (+2 -2)
  3. modelscope/models/science/unifold/modules/attentions.py (+1 -1)
  4. modelscope/pipelines/base.py (+1 -1)
  5. modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+3 -3)
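Every hunk below is the same one-token substitution: torch.concat becomes torch.cat. For context (this reasoning is not in the commit itself): torch.concat was only added as an alias of torch.cat in a later PyTorch release (around 1.10), so under torch 1.8 the call fails with AttributeError: module 'torch' has no attribute 'concat', while torch.cat is available in every release. A minimal illustrative sketch, with made-up tensors:

    import torch

    a = torch.zeros(2, 3)
    b = torch.ones(2, 3)

    # On torch 1.8 this line would raise:
    #   AttributeError: module 'torch' has no attribute 'concat'
    # out = torch.concat([a, b], dim=0)

    # torch.cat exists on all releases and is the function the
    # concat alias resolves to on newer ones.
    out = torch.cat([a, b], dim=0)
    print(out.shape)  # torch.Size([4, 3])

torch.cat also accepts the NumPy-style axis= keyword as a synonym for dim=, which is why the body_3d_pose.py hunk can keep axis=0 unchanged.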

modelscope/models/cv/body_3d_keypoints/body_3d_pose.py (+2 -2)

@@ -224,8 +224,8 @@ class BodyKeypointsDetection3D(TorchModel):
             lst_pose2d_cannoical.append(pose2d_canonical[:,
                                                          i - pad:i + pad + 1])
 
-        input_pose2d_rr = torch.concat(lst_pose2d_cannoical, axis=0)
-        input_pose2d_cannoical = torch.concat(lst_pose2d_cannoical, axis=0)
+        input_pose2d_rr = torch.cat(lst_pose2d_cannoical, axis=0)
+        input_pose2d_cannoical = torch.cat(lst_pose2d_cannoical, axis=0)
 
         if self.cfg.model.MODEL.USE_CANONICAL_COORDS:
             input_pose2d_abs = input_pose2d_cannoical.clone()


modelscope/models/science/unifold/data/data_ops.py (+2 -2)

@@ -730,7 +730,7 @@ def make_msa_feat_v2(batch):
         batch['cluster_profile'],
         deletion_mean_value,
     ]
-    batch['msa_feat'] = torch.concat(msa_feat, dim=-1)
+    batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
     return batch




@@ -1320,7 +1320,7 @@ def get_contiguous_crop_idx(
                     asym_offset + this_start + csz))
             asym_offset += ll
 
-    return torch.concat(crop_idxs)
+    return torch.cat(crop_idxs)
 
 
 def get_spatial_crop_idx(


modelscope/models/science/unifold/modules/attentions.py (+1 -1)

@@ -217,7 +217,7 @@ class MSAAttention(nn.Module):
                     if mask is not None else None)
             outputs.append(
                 self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias))
-        return torch.concat(outputs, dim=-3)
+        return torch.cat(outputs, dim=-3)
 
     def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
         m = self.layer_norm_m(m)


modelscope/pipelines/base.py (+1 -1)

@@ -233,7 +233,7 @@ class Pipeline(ABC):
                 batch_data[k] = value_list
         for k in batch_data.keys():
             if isinstance(batch_data[k][0], torch.Tensor):
-                batch_data[k] = torch.concat(batch_data[k])
+                batch_data[k] = torch.cat(batch_data[k])
         return batch_data
 
     def _process_batch(self, input: List[Input], batch_size,


modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+3 -3)

@@ -46,17 +46,17 @@ class ImageCaptioningPipeline(Pipeline):
             batch_data['samples'] = [d['samples'][0] for d in data]
             batch_data['net_input'] = {}
             for k in data[0]['net_input'].keys():
-                batch_data['net_input'][k] = torch.concat(
+                batch_data['net_input'][k] = torch.cat(
                     [d['net_input'][k] for d in data])
 
             return batch_data
         elif isinstance(self.model, MPlugForAllTasks):
             from transformers.tokenization_utils_base import BatchEncoding
             batch_data = dict(train=data[0]['train'])
-            batch_data['image'] = torch.concat([d['image'] for d in data])
+            batch_data['image'] = torch.cat([d['image'] for d in data])
             question = {}
             for k in data[0]['question'].keys():
-                question[k] = torch.concat([d['question'][k] for d in data])
+                question[k] = torch.cat([d['question'][k] for d in data])
             batch_data['question'] = BatchEncoding(question)
             return batch_data
         else:
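An alternative for codebases that want to keep the concat spelling while still supporting older PyTorch is a one-line feature-detection shim. This is a hypothetical pattern, not something this commit does:

    import torch

    # Use torch.concat where it exists; fall back to torch.cat on
    # releases that predate the alias.
    concat = getattr(torch, 'concat', torch.cat)

    x = concat([torch.randn(1, 4), torch.randn(1, 4)], dim=0)
    print(x.shape)  # torch.Size([2, 4])

Renaming the call sites directly, as this commit does, is the simpler choice: it drops the shim entirely and matches what torch.concat resolves to anyway.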

