You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import logging
  2. import tensorflow as tf
  3. import neptune
  4. from interface import Interface
  5. from neptune.incremental_learning import IncrementalConfig
  6. LOG = logging.getLogger(__name__)
  7. MODEL_URL = IncrementalConfig().model_url
  8. def main():
  9. tf.set_random_seed(22)
  10. class_names = neptune.context.get_parameters("class_names")
  11. # load dataset.
  12. train_data = neptune.load_train_dataset(data_format='txt',
  13. with_image=False)
  14. # read parameters from deployment config.
  15. obj_threshold = neptune.context.get_parameters("obj_threshold")
  16. nms_threshold = neptune.context.get_parameters("nms_threshold")
  17. input_shape = neptune.context.get_parameters("input_shape")
  18. epochs = neptune.context.get_parameters('epochs')
  19. batch_size = neptune.context.get_parameters('batch_size')
  20. tf.flags.DEFINE_string('train_url', default=MODEL_URL,
  21. help='train url for model')
  22. tf.flags.DEFINE_string('log_url', default=None, help='log url for model')
  23. tf.flags.DEFINE_string('checkpoint_url', default=None,
  24. help='checkpoint url for model')
  25. tf.flags.DEFINE_string('model_name', default=None,
  26. help='url for train annotation files')
  27. tf.flags.DEFINE_list('class_names', default=class_names.split(','),
  28. # 'helmet,helmet-on,person,helmet-off'
  29. help='label names for the training datasets')
  30. tf.flags.DEFINE_list('input_shape',
  31. default=[int(x) for x in input_shape.split(',')],
  32. help='input_shape') # [352, 640]
  33. tf.flags.DEFINE_integer('max_epochs', default=epochs,
  34. help='training number of epochs')
  35. tf.flags.DEFINE_integer('batch_size', default=batch_size,
  36. help='training batch size')
  37. tf.flags.DEFINE_boolean('load_imagenet_weights', default=False,
  38. help='if load imagenet weights or not')
  39. tf.flags.DEFINE_string('inference_device',
  40. default='GPU',
  41. help='which type of device is used to do inference,'
  42. ' only CPU, GPU or 310D')
  43. tf.flags.DEFINE_boolean('copy_to_local', default=True,
  44. help='if load imagenet weights or not')
  45. tf.flags.DEFINE_integer('num_gpus', default=1, help='use number of gpus')
  46. tf.flags.DEFINE_boolean('finetuning', default=False,
  47. help='use number of gpus')
  48. tf.flags.DEFINE_boolean('label_changed', default=False,
  49. help='whether number of labels is changed or not')
  50. tf.flags.DEFINE_string('learning_rate', default='0.001',
  51. help='label names for the training datasets')
  52. tf.flags.DEFINE_string('obj_threshold', default=obj_threshold,
  53. help='label names for the training datasets')
  54. tf.flags.DEFINE_string('nms_threshold', default=nms_threshold,
  55. help='label names for the training datasets')
  56. tf.flags.DEFINE_string('net_type', default='resnet18',
  57. help='resnet18 or resnet18_nas')
  58. tf.flags.DEFINE_string('nas_sequence', default='64_1-2111-2-1112',
  59. help='resnet18 or resnet18_nas')
  60. tf.flags.DEFINE_string('deploy_model_format', default=None,
  61. help='the format for the converted model')
  62. tf.flags.DEFINE_string('result_url', default=None,
  63. help='result url for training')
  64. model = Interface()
  65. neptune.incremental_learning.train(model=model,
  66. train_data=train_data,
  67. epochs=epochs,
  68. batch_size=batch_size,
  69. class_names=class_names,
  70. input_shape=input_shape,
  71. obj_threshold=obj_threshold,
  72. nms_threshold=nms_threshold)
  73. if __name__ == '__main__':
  74. main()