You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

compute_utils.py 833 B

1234567891011121314151617181920212223242526272829303132
  1. import torch
  2. def one_hot_encoder(input_tensor, n_classes):
  3. """
  4. 将输入tensor转化为one-hot形式
  5. :param input_tensor:
  6. :param n_classes:
  7. :return:
  8. """
  9. tensor_list = []
  10. for i in range(n_classes):
  11. temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
  12. tensor_list.append(temp_prob.unsqueeze(1))
  13. output_tensor = torch.cat(tensor_list, dim=1)
  14. return output_tensor.long()
  15. def torch_nanmean(x):
  16. """
  17. 输出忽略nan的tensor均值
  18. :param x:
  19. :return:
  20. """
  21. num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
  22. value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
  23. # num为0表示均为nan, 此时由于分母不能为0, 则设num为1
  24. if num == 0:
  25. num = 1
  26. return value / num

基于pytorch lightning的机器学习模板, 用于对机器学习算法进行训练, 验证, 测试等, 目前实现了神经网路, 深度学习, k折交叉, 自动保存训练信息等.

Contributors (1)