|
"""Sanity check: a permute -> flatten -> reshape -> permute round trip is lossless.

Builds a random (N, C, L) tensor, moves axis 1 to the end and flattens the
leading axes into rows of length C, then undoes both steps and verifies the
result equals the original exactly.
"""
import torch

# Test tensor shape (N, C, L). The sizes are arbitrary; axis 1 (size C) is the
# one moved last and kept as the row length of the flattened 2-D view.
N, C, L = 8, 11, 9

# Fall back to CPU so the check also runs on machines without a GPU
# (the original hard-coded .cuda() and crashed there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cls_score_original = torch.rand(N, C, L, device=device)
print(cls_score_original)

# (N, C, L) -> (N, L, C) -> (N * L, C). The permuted layout cannot be viewed as
# (N * L, C), so reshape() copies into a contiguous tensor; an extra
# .contiguous() call would be a no-op.
cls_score_flat = cls_score_original.permute(0, 2, 1).reshape(-1, C)
print(cls_score_flat)

# Inverse: (N * L, C) -> (N, L, C) -> (N, C, L). The result is a permuted
# (non-contiguous) view, which is fine for an equality check.
cls_score = cls_score_flat.reshape(N, L, C).permute(0, 2, 1)
print(cls_score)

# torch.equal yields a single bool (same shape and same values) instead of a
# full elementwise mask that would have to be inspected by eye.
roundtrip_ok = torch.equal(cls_score_original, cls_score)
print(roundtrip_ok)
assert roundtrip_ok, "permute/reshape round trip did not reproduce the input"
|