|
1234567891011121314151617181920212223242526272829303132 |
- import torch
-
-
- def one_hot_encoder(input_tensor, n_classes):
- """
- 将输入tensor转化为one-hot形式
-
- :param input_tensor:
- :param n_classes:
- :return:
- """
- tensor_list = []
- for i in range(n_classes):
- temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
- tensor_list.append(temp_prob.unsqueeze(1))
- output_tensor = torch.cat(tensor_list, dim=1)
- return output_tensor.long()
-
-
- def torch_nanmean(x):
- """
- 输出忽略nan的tensor均值
-
- :param x:
- :return:
- """
- num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
- value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
- # num为0表示均为nan, 此时由于分母不能为0, 则设num为1
- if num == 0:
- num = 1
- return value / num
|