You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

args.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. import torch
  3. from sedna.common.config import BaseConfig
  4. class Arguments:
  5. '''
  6. Setting basic arguments for RFNet model
  7. '''
  8. def __init__(self, **kwargs):
  9. # whether to use depth images or not
  10. self.depth = kwargs.get("depth", False)
  11. # number of dataloader threads
  12. self.workers = int(kwargs.get("workers", 0))
  13. self.base_size = int(kwargs.get("base-size", 1024)) # base image size
  14. self.crop_size = int(kwargs.get("crop_size", 768)) # crop image size
  15. self.image_size = kwargs.get(
  16. "image_size", (2048, 1024)) # output image shape
  17. # input batch size for training
  18. self.batch_size = int(kwargs.get("batch_size", 4))
  19. self.val_batch_size = int(kwargs.get(
  20. "val_batch_size", 1)) # input batch size for validation
  21. self.test_batch_size = int(kwargs.get(
  22. "test_batch_size", 1)) # input batch size for testing
  23. self.num_class = int(kwargs.get(
  24. "num_class", 31)) # number of training classes
  25. # whether to disable CUDA for training
  26. self.no_cuda = kwargs.get("no_cuda", False)
  27. # use which gpu to train which must be a comma-separated list of
  28. # integers only
  29. self.gpu_ids = kwargs.get("gpu_ids", "0, 1")
  30. self.checkname = kwargs.get(
  31. "checkname", "RFNet") # set the checkpoint name
  32. self.cuda = not self.no_cuda and torch.cuda.is_available()
  33. if self.cuda:
  34. try:
  35. self.gpu_ids = [int(s) for s in self.gpu_ids.split(',')]
  36. except ValueError:
  37. raise ValueError(
  38. 'Argument --gpu_ids must be a comma-separated list of integers only')
  39. class TrainingArguments(Arguments):
  40. '''
  41. Setting basic arguments for RFNet training
  42. '''
  43. def __init__(self, **kwargs):
  44. super(TrainingArguments, self).__init__(**kwargs)
  45. self.loss_type = kwargs.pop('loss_type', "ce") # loss function type
  46. # number of epochs to train
  47. self.epochs = int(kwargs.get("epochs", 200))
  48. # the index of epoch to start training
  49. self.start_epoch = int(kwargs.get("start_epoch", 0))
  50. self.use_balanced_weights = kwargs.get(
  51. "use_balanced_weights",
  52. False) # whether to use balanced weights
  53. # if use balanced weights, specify weight path
  54. self.class_weight_path = kwargs.get("class_weight_path", None)
  55. self.lr = float(kwargs.get("lr", 1e-4)) # learning rate
  56. self.lr_scheduler = kwargs.get(
  57. "lr_scheduler", "cos") # lr scheduler mode
  58. self.momentum = float(kwargs.get("momentum", 0.9))
  59. self.weight_decay = float(kwargs.get("weight_decay", 2.5e-5))
  60. self.seed = int(kwargs.get("seed", 1)) # random seed
  61. # put the path to resuming file if needed
  62. self.resume = kwargs.get("resume", None)
  63. # whether to finetune on a different dataset
  64. self.ft = kwargs.get("ft", True)
  65. self.eval_interval = int(
  66. kwargs.get(
  67. "eval_interval",
  68. 100)) # evaluation interval
  69. # whether to skip validation during training
  70. self.no_val = kwargs.get("no_val", True)
  71. if not self.batch_size:
  72. self.batch_size = 4 * len(self.gpu_ids)
  73. torch.manual_seed(self.seed)
  74. class EvaluationArguments(Arguments):
  75. '''
  76. Setting basic arguments for RFNet evaluation
  77. '''
  78. def __init__(self, **kwargs):
  79. super(EvaluationArguments, self).__init__(**kwargs)
  80. self.weight_path = kwargs.get('weight_path') # path of the weight
  81. # whether to merge images and labels
  82. self.merge = kwargs.get('merge', True)
  83. self.save_predicted_image = kwargs.get(
  84. 'save_predicted_image',
  85. False) # whether to save the predicted images
  86. self.color_label_save_path = kwargs.get('color_label_save_path', os.path.join(
  87. BaseConfig.data_path_prefix, "inference_results/color")) # path to save colored label images
  88. self.merge_label_save_path = kwargs.get('merge_label_save_path', os.path.join(
  89. BaseConfig.data_path_prefix, "inference_results/merge")) # path to save merged label images
  90. self.label_save_path = kwargs.get("label_save_path", os.path.join(
  91. BaseConfig.data_path_prefix, "inference_results/label")) # path to save label images