You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

calculate_weights.py 954 B

12345678910111213141516171819202122232425262728
  1. import os
  2. from tqdm import tqdm
  3. import numpy as np
  4. def calculate_weigths_labels(class_weight_path, dataloader, num_classes):
  5. # Create an instance from the data loader
  6. z = np.zeros((num_classes,))
  7. # Initialize tqdm
  8. tqdm_batch = tqdm(dataloader)
  9. print('Calculating classes weights')
  10. for sample in tqdm_batch:
  11. y = sample['label']
  12. y = y.detach().cpu().numpy()
  13. mask = (y >= 0) & (y < num_classes)
  14. labels = y[mask].astype(np.uint8)
  15. count_l = np.bincount(labels, minlength=num_classes)
  16. z += count_l
  17. tqdm_batch.close()
  18. total_frequency = np.sum(z)
  19. class_weights = []
  20. for frequency in z:
  21. class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))
  22. class_weights.append(class_weight)
  23. ret = np.array(class_weights)
  24. classes_weights_path = os.path.join(class_weight_path, 'classes_weights.npy')
  25. np.save(classes_weights_path, ret)
  26. return ret