You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import numpy as np
  3. import random
  4. import math
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. from example_files.model import ConvModel
  9. class ImageDataLoader:
  10. def __init__(self, data_root, train: bool = True):
  11. self.data_root = data_root
  12. self.train = train
  13. def get_idx_data(self, idx=0):
  14. if self.train:
  15. X_path = os.path.join(self.data_root, "uploader", "uploader_%d_X.npy" % (idx))
  16. y_path = os.path.join(self.data_root, "uploader", "uploader_%d_y.npy" % (idx))
  17. if not (os.path.exists(X_path) and os.path.exists(y_path)):
  18. raise Exception("Index Error")
  19. X = np.load(X_path)
  20. y = np.load(y_path)
  21. else:
  22. X_path = os.path.join(self.data_root, "user", "user_%d_X.npy" % (idx))
  23. y_path = os.path.join(self.data_root, "user", "user_%d_y.npy" % (idx))
  24. if not (os.path.exists(X_path) and os.path.exists(y_path)):
  25. raise Exception("Index Error")
  26. X = np.load(X_path)
  27. y = np.load(y_path)
  28. return X, y
  29. def generate_uploader(data_x, data_y, n_uploaders=50, data_save_root=None):
  30. if data_save_root is None:
  31. return
  32. os.makedirs(data_save_root, exist_ok=True)
  33. for i in range(n_uploaders):
  34. random_class_num = random.randint(6, 10)
  35. cls_indx = list(range(10))
  36. random.shuffle(cls_indx)
  37. selected_cls_indx = cls_indx[:random_class_num]
  38. rest_cls_indx = cls_indx[random_class_num:]
  39. selected_data_indx = []
  40. for cls in selected_cls_indx:
  41. data_indx = list(torch.where(data_y == cls)[0])
  42. # print(type(data_indx))
  43. random.shuffle(data_indx)
  44. data_num = random.randint(800, 2000)
  45. selected_indx = data_indx[:data_num]
  46. selected_data_indx = selected_data_indx + selected_indx
  47. for cls in rest_cls_indx:
  48. flag = random.randint(0, 1)
  49. if flag == 0:
  50. continue
  51. data_indx = list(torch.where(data_y == cls)[0])
  52. random.shuffle(data_indx)
  53. data_num = random.randint(20, 80)
  54. selected_indx = data_indx[:data_num]
  55. selected_data_indx = selected_data_indx + selected_indx
  56. selected_X = data_x[selected_data_indx].numpy()
  57. selected_y = data_y[selected_data_indx].numpy()
  58. print(selected_X.dtype, selected_y.dtype)
  59. print(selected_X.shape, selected_y.shape)
  60. X_save_dir = os.path.join(data_save_root, "uploader_%d_X.npy" % (i))
  61. y_save_dir = os.path.join(data_save_root, "uploader_%d_y.npy" % (i))
  62. np.save(X_save_dir, selected_X)
  63. np.save(y_save_dir, selected_y)
  64. print("Saving to %s" % (X_save_dir))
  65. def generate_user(data_x, data_y, n_users=50, data_save_root=None):
  66. if data_save_root is None:
  67. return
  68. os.makedirs(data_save_root, exist_ok=True)
  69. for i in range(n_users):
  70. random_class_num = random.randint(3, 6)
  71. cls_indx = list(range(10))
  72. random.shuffle(cls_indx)
  73. selected_cls_indx = cls_indx[:random_class_num]
  74. selected_data_indx = []
  75. for cls in selected_cls_indx:
  76. data_indx = list(torch.where(data_y == cls)[0])
  77. # print(type(data_indx))
  78. random.shuffle(data_indx)
  79. data_num = random.randint(150, 350)
  80. selected_indx = data_indx[:data_num]
  81. selected_data_indx = selected_data_indx + selected_indx
  82. # print('Total Index:', len(selected_data_indx))
  83. selected_X = data_x[selected_data_indx].numpy()
  84. selected_y = data_y[selected_data_indx].numpy()
  85. print(selected_X.shape, selected_y.shape)
  86. X_save_dir = os.path.join(data_save_root, "user_%d_X.npy" % (i))
  87. y_save_dir = os.path.join(data_save_root, "user_%d_y.npy" % (i))
  88. np.save(X_save_dir, selected_X)
  89. np.save(y_save_dir, selected_y)
  90. print("Saving to %s" % (X_save_dir))
  91. # Train Uploaders' models
  92. def train(X, y, out_classes, epochs=35, batch_size=128):
  93. print(X.shape, y.shape)
  94. input_feature = X.shape[1]
  95. data_size = X.shape[0]
  96. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  97. model = ConvModel(channel=input_feature, n_random_features=out_classes).to(device)
  98. model.train()
  99. # Adam optimizer with learning rate 1e-3
  100. # optimizer = optim.Adam(model.parameters(), lr=1e-3)
  101. # SGD optimizer with learning rate 1e-2
  102. optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
  103. # mean-squared error loss
  104. criterion = nn.CrossEntropyLoss()
  105. for epoch in range(epochs):
  106. running_loss = []
  107. indx = list(range(data_size))
  108. random.shuffle(indx)
  109. curr_X = X[indx]
  110. curr_y = y[indx]
  111. for i in range(math.floor(data_size / batch_size)):
  112. inputs, annos = curr_X[i * batch_size : (i + 1) * batch_size], curr_y[i * batch_size : (i + 1) * batch_size]
  113. inputs = torch.from_numpy(inputs).to(device)
  114. annos = torch.from_numpy(annos).to(device)
  115. # print(inputs.dtype, annos.dtype)
  116. out = model(inputs)
  117. optimizer.zero_grad()
  118. loss = criterion(out, annos)
  119. loss.backward()
  120. optimizer.step()
  121. running_loss.append(loss.item())
  122. # print('Epoch: %d, Average Loss: %.3f'%(epoch+1, np.mean(running_loss)))
  123. # Train Accuracy
  124. acc = test(X, y, model)
  125. model.train()
  126. return model
  127. def test(test_X, test_y, model, batch_size=128):
  128. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  129. model.eval()
  130. total, correct = 0, 0
  131. data_size = test_X.shape[0]
  132. for i in range(math.ceil(data_size / batch_size)):
  133. inputs, annos = test_X[i * batch_size : (i + 1) * batch_size], test_y[i * batch_size : (i + 1) * batch_size]
  134. inputs = torch.Tensor(inputs).to(device)
  135. annos = torch.Tensor(annos).to(device)
  136. out = model(inputs)
  137. _, predicted = torch.max(out.data, 1)
  138. total += annos.size(0)
  139. correct += (predicted == annos).sum().item()
  140. acc = correct / total * 100
  141. print("Accuracy: %.2f" % (acc))
  142. return acc
  143. def eval_prediction(pred_y, target_y):
  144. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  145. if not isinstance(pred_y, np.ndarray):
  146. pred_y = pred_y.detach().cpu().numpy()
  147. predicted = np.argmax(pred_y, 1)
  148. # print(predicted)
  149. # annos = torch.from_numpy(target_y).to(device)
  150. annos = target_y
  151. total = annos.shape[0]
  152. correct = (predicted == annos).sum().item()
  153. criterion = nn.CrossEntropyLoss()
  154. return correct / total