
utils.py 9.1 kB

# Utility functions that are independent of the network itself
import colorsys
import glob
import os
import random
import zipfile

import cv2
import torch
from torchvision import transforms


def divide_dataset(dataset_path, rate_datasets):
    """
    Split the dataset into training, validation, and test sets, and save the
    resulting lists as train_dataset_list, validate_dataset_list, and
    test_dataset_list. Each ratio must be greater than 0 and leave at least
    one sample in its subset; only the validation ratio may be 0.
    :param dataset_path: path to the dataset
    :param rate_datasets: ratios [train, validate, test] of the three subsets
    """
    # Generate all_dataset_list.txt from the label files if it does not exist yet
    if not os.path.exists(dataset_path + '/all_dataset_list.txt'):
        all_list = glob.glob(dataset_path + '/labels' + '/*.png')
        with open(dataset_path + '/all_dataset_list.txt', 'w', encoding='utf-8') as f:
            for line in all_list:
                f.write(os.path.basename(line.replace('\\', '/')) + '\n')
    path_train_dataset_list = dataset_path + '/train_dataset_list.txt'
    path_validate_dataset_list = dataset_path + '/validate_dataset_list.txt'
    path_test_dataset_list = dataset_path + '/test_dataset_list.txt'
    # If the validation ratio is 0, the test set doubles as the validation set
    if rate_datasets[1] == 0:
        # No split list files yet: generate new ones
        if not (os.path.exists(path_train_dataset_list) and
                os.path.exists(path_validate_dataset_list) and
                os.path.exists(path_test_dataset_list)):
            with open(dataset_path + '/all_dataset_list.txt', encoding='utf-8') as f:
                all_list = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
            with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                for line in train_dataset_list:
                    f.write(line)
            # The test list is written to both the validation and the test file
            with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            print('Generated new dataset lists')
        else:
            # Check whether the existing split still matches the ratios;
            # if not, regenerate the dataset lists
            with open(dataset_path + '/all_dataset_list.txt', encoding='utf-8') as f:
                all_list = f.readlines()
            with open(path_train_dataset_list, encoding='utf-8') as f:
                train_dataset_list_exist = f.readlines()
            with open(path_validate_dataset_list, encoding='utf-8') as f:
                test_dataset_list_exist = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            test_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):]
            if not (len(train_dataset_list_exist) == len(train_dataset_list) and
                    len(test_dataset_list_exist) == len(test_dataset_list)):
                with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                    for line in train_dataset_list:
                        f.write(line)
                with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                print('Generated new dataset lists')
    # A non-zero validation ratio means separate validation and test sets exist
    else:
        # No split list files yet: generate new ones
        if not (os.path.exists(path_train_dataset_list) and
                os.path.exists(path_validate_dataset_list) and
                os.path.exists(path_test_dataset_list)):
            with open(dataset_path + '/all_dataset_list.txt', encoding='utf-8') as f:
                all_list = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
                                             int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
            test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
            with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                for line in train_dataset_list:
                    f.write(line)
            with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                for line in validate_dataset_list:
                    f.write(line)
            with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                for line in test_dataset_list:
                    f.write(line)
            print('Generated new dataset lists')
        else:
            # Check whether the existing split still matches the ratios;
            # if not, regenerate the dataset lists
            with open(dataset_path + '/all_dataset_list.txt', encoding='utf-8') as f:
                all_list = f.readlines()
            with open(path_train_dataset_list, encoding='utf-8') as f:
                train_dataset_list_exist = f.readlines()
            with open(path_validate_dataset_list, encoding='utf-8') as f:
                validate_dataset_list_exist = f.readlines()
            with open(path_test_dataset_list, encoding='utf-8') as f:
                test_dataset_list_exist = f.readlines()
            random.shuffle(all_list)
            train_dataset_list = all_list[0:int(len(all_list) * rate_datasets[0])]
            validate_dataset_list = all_list[int(len(all_list) * rate_datasets[0]):
                                             int(len(all_list) * (rate_datasets[0] + rate_datasets[1]))]
            test_dataset_list = all_list[int(len(all_list) * (rate_datasets[0] + rate_datasets[1])):]
            if not (len(train_dataset_list_exist) == len(train_dataset_list) and
                    len(validate_dataset_list_exist) == len(validate_dataset_list) and
                    len(test_dataset_list_exist) == len(test_dataset_list)):
                with open(path_train_dataset_list, 'w', encoding='utf-8') as f:
                    for line in train_dataset_list:
                        f.write(line)
                with open(path_validate_dataset_list, 'w', encoding='utf-8') as f:
                    for line in validate_dataset_list:
                        f.write(line)
                with open(path_test_dataset_list, 'w', encoding='utf-8') as f:
                    for line in test_dataset_list:
                        f.write(line)
                print('Generated new dataset lists')
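A minimal usage sketch for the function above; the dataset path and ratios here are hypothetical, and labels are assumed to live under `<dataset_path>/labels/*.png` as the code expects:

    # Hypothetical example: 80% train / 10% validate / 10% test split
    divide_dataset('./my_dataset', [0.8, 0.1, 0.1])
    # With a zero validation ratio, the test list is also written to
    # validate_dataset_list.txt:
    # divide_dataset('./my_dataset', [0.9, 0, 0.1])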
def zip_dir(dir_path, zip_path):
    """
    Compress a directory into a zip archive.
    :param dir_path: path of the directory to compress
    :param zip_path: path of the resulting zip file
    """
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipper:
        for root, dirnames, filenames in os.walk(dir_path):
            # Strip the root prefix so only paths inside dir_path are archived
            file_path = root.replace(dir_path, '')
            for filename in filenames:
                zipper.write(os.path.join(root, filename), os.path.join(file_path, filename))
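For example (paths here are hypothetical), archiving a results directory:

    # Compress ./logs into ./logs_backup.zip, keeping paths relative to ./logs
    zip_dir('./logs', './logs_backup.zip')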
def ncolors(num_colors):
    """
    Generate a set of colors with high mutual contrast.
    copy: https://blog.csdn.net/choumin/article/details/90320297
    :param num_colors: number of colors
    :return: list of [r, g, b] values in 0-255
    """
    def get_n_hls_colors(num):
        # Sample hues evenly around the color wheel, with slightly
        # randomized saturation and lightness
        hls_colors = []
        i = 0
        step = 360.0 / num
        while i < 360:
            h = i
            s = 90 + random.random() * 10
            li = 50 + random.random() * 10
            _hlsc = [h / 360.0, li / 100.0, s / 100.0]
            hls_colors.append(_hlsc)
            i += step
        return hls_colors

    rgb_colors = []
    if num_colors < 1:
        return rgb_colors
    for hlsc in get_n_hls_colors(num_colors):
        _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
        r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
        rgb_colors.append([r, g, b])
    return rgb_colors
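A quick sketch of how the palette might be used, e.g. assigning one RGB color per segmentation class; the class count is hypothetical:

    # Hypothetical example: one visually distinct color per class for 5 classes
    palette = ncolors(5)
    print(palette)  # 5 [r, g, b] lists; exact values vary because s and l are randomized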
def visual_label(dataset_path, n_classes):
    """
    Visualize label images.
    :param dataset_path: dataset path
    :param n_classes: number of classes
    """
    label_path = os.path.join(dataset_path, 'test', 'labels').replace('\\', '/')
    label_image_list = glob.glob(label_path + '/*.png')
    label_image_list.sort()
    trans_factory = transforms.ToPILImage()
    if not os.path.exists(dataset_path + '/visual_label'):
        os.mkdir(dataset_path + '/visual_label')
    for index in range(len(label_image_list)):
        # Read the single-channel label image unchanged
        label_image = cv2.imread(label_image_list[index], -1)
        name = os.path.basename(label_image_list[index])
        # Scale class indices into [0, 1] so they are visible as gray levels
        trans_factory(torch.from_numpy(label_image).float() / n_classes).save(
            dataset_path + '/visual_label/' + name,
            quality=95)
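A usage sketch, assuming the `<dataset_path>/test/labels/*.png` layout the function above reads from; the path and class count are hypothetical:

    # Render the label maps of a 21-class dataset as grayscale images
    # under ./my_dataset/visual_label
    visual_label('./my_dataset', 21)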

A machine learning template based on PyTorch Lightning for training, validating, and testing machine learning algorithms. It currently supports neural networks, deep learning, k-fold cross-validation, and automatic saving of training information.
