You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # 包含一些与网络无关的工具
  2. import glob
  3. import os
  4. import random
  5. import zipfile
  6. import cv2
  7. import torch
  8. def get_dataset_list(dataset_path):
  9. if not os.path.exists(dataset_path + '/dataset_list.txt'):
  10. all_list = glob.glob(dataset_path + '/labels' + '/*.png')
  11. random.shuffle(all_list)
  12. all_list = [os.path.basename(item.replace('\\', '/')) for item in all_list]
  13. written = all_list
  14. with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
  15. for line in written:
  16. f.write(line + '\n')
  17. print('已生成新的数据list')
  18. return all_list
  19. else:
  20. all_list = open(dataset_path + '/all_dataset_list.txt').readlines()
  21. return all_list
  22. def zip_dir(dir_path, zip_path):
  23. """
  24. 压缩文件
  25. :param dir_path: 目标文件夹路径
  26. :param zip_path: 压缩后的文件夹路径
  27. """
  28. ziper = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
  29. for root, dirnames, filenames in os.walk(dir_path):
  30. file_path = root.replace(dir_path, '') # 去掉根路径,只对目标文件夹下的文件及文件夹进行压缩
  31. # 循环出一个个文件名
  32. for filename in filenames:
  33. ziper.write(os.path.join(root, filename), os.path.join(file_path, filename))
  34. ziper.close()
  35. def ncolors(num_colors):
  36. """
  37. 生成区别度较大的几种颜色
  38. copy: https://blog.csdn.net/choumin/article/details/90320297
  39. :param num_colors: 颜色数
  40. :return:
  41. """
  42. def get_n_hls_colors(num):
  43. import random
  44. hls_colors = []
  45. i = 0
  46. step = 360.0 / num
  47. while i < 360:
  48. h = i
  49. s = 90 + random.random() * 10
  50. li = 50 + random.random() * 10
  51. _hlsc = [h / 360.0, li / 100.0, s / 100.0]
  52. hls_colors.append(_hlsc)
  53. i += step
  54. return hls_colors
  55. import colorsys
  56. rgb_colors = []
  57. if num_colors < 1:
  58. return rgb_colors
  59. for hlsc in get_n_hls_colors(num_colors):
  60. _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
  61. r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
  62. rgb_colors.append([r, g, b])
  63. return rgb_colors
  64. def visual_label(dataset_path, n_classes):
  65. """
  66. 将标签可视化
  67. :param dataset_path: 地址
  68. :param n_classes: 类别数
  69. """
  70. label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/')
  71. label_image_list = glob.glob(label_path + '/*.png')
  72. label_image_list.sort()
  73. from torchvision import transforms
  74. trans_factory = transforms.ToPILImage()
  75. if not os.path.exists(dataset_path + '/visual_label'):
  76. os.mkdir(dataset_path + '/visual_label')
  77. for index in range(len(label_image_list)):
  78. label_image = cv2.imread(label_image_list[index], -1)
  79. name = os.path.basename(label_image_list[index])
  80. trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
  81. dataset_path + '/visual_label/' + name,
  82. quality=95)
  83. def get_ckpt_path(version_nth: int, kth_fold: int):
  84. if version_nth is None:
  85. return None
  86. else:
  87. version_name = f'version_{version_nth + kth_fold}'
  88. checkpoints_path = './logs/default/' + version_name + '/checkpoints'
  89. ckpt_path = glob.glob(checkpoints_path + '/*.ckpt')
  90. return ckpt_path[0].replace('\\', '/')
  91. def rwxl():
  92. # 写
  93. # dataset_xl = xl.Workbook(write_only=True)
  94. # dataset_sh = dataset_xl.create_sheet('dataset', 0)
  95. # for row in range(self.x.shape[0]):
  96. # for col in range(self.x.shape[1]):
  97. # dataset_sh.cell(row + 1, col + 1).value = float(self.x[row, col])
  98. # dataset_sh.cell(row + 1, self.x.shape[1] + 1).value = float(self.y[row])
  99. # dataset_xl.save(dataset_path + '/dataset.xlsx')
  100. # dataset_xl.close()
  101. # 读
  102. # dataset_xl = xl.load_workbook(dataset_path + '/dataset_list.xlsx', read_only=True)
  103. # dataset_sh = dataset_xl.get_sheet_by_name('dataset_list')
  104. # temp = [[dataset_sh[row + 1][col].value for col in range(config['dim_in'] + 1)] for row in
  105. # range(config['dataset_len'])]
  106. # dataset_xl.close()
  107. pass
  108. if __name__ == "__main__":
  109. get_ckpt_path('version_0')

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等。

Contributors (1)