You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 3.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import logging
  2. import tensorflow as tf
  3. import neptune
  4. from interface import Interface
  5. from neptune.incremental_learning.incremental_learning import IncrementalConfig
  6. LOG = logging.getLogger(__name__)
  7. MODEL_URL = IncrementalConfig().model_url
  8. def main():
  9. tf.set_random_seed(22)
  10. class_names = neptune.context.get_parameters("class_names")
  11. # load dataset.
  12. train_data = neptune.load_train_dataset(data_format='txt', with_image=False)
  13. # read parameters from deployment config.
  14. obj_threshold = neptune.context.get_parameters("obj_threshold")
  15. nms_threshold = neptune.context.get_parameters("nms_threshold")
  16. input_shape = neptune.context.get_parameters("input_shape")
  17. epochs = neptune.context.get_parameters('epochs')
  18. batch_size = neptune.context.get_parameters('batch_size')
  19. tf.flags.DEFINE_string('train_url', default=MODEL_URL, help='train url for model')
  20. tf.flags.DEFINE_string('log_url', default=None, help='log url for model')
  21. tf.flags.DEFINE_string('checkpoint_url', default=None, help='checkpoint url for model')
  22. tf.flags.DEFINE_string('model_name', default=None, help='url for train annotation files')
  23. tf.flags.DEFINE_list('class_names', default=class_names.split(','), # 'helmet,helmet-on,person,helmet-off'
  24. help='label names for the training datasets')
  25. tf.flags.DEFINE_list('input_shape', default=[int(x) for x in input_shape.split(',')],
  26. help='input_shape') # [352, 640]
  27. tf.flags.DEFINE_integer('max_epochs', default=epochs, help='training number of epochs')
  28. tf.flags.DEFINE_integer('batch_size', default=batch_size, help='training batch size')
  29. tf.flags.DEFINE_boolean('load_imagenet_weights', default=False, help='if load imagenet weights or not')
  30. tf.flags.DEFINE_string('inference_device',
  31. default='GPU',
  32. help='which type of device is used to do inference, only CPU, GPU or 310D')
  33. tf.flags.DEFINE_boolean('copy_to_local', default=True, help='if load imagenet weights or not')
  34. tf.flags.DEFINE_integer('num_gpus', default=1, help='use number of gpus')
  35. tf.flags.DEFINE_boolean('finetuning', default=False, help='use number of gpus')
  36. tf.flags.DEFINE_boolean('label_changed', default=False, help='whether number of labels is changed or not')
  37. tf.flags.DEFINE_string('learning_rate', default='0.001', help='label names for the training datasets')
  38. tf.flags.DEFINE_string('obj_threshold', default=obj_threshold, help='label names for the training datasets')
  39. tf.flags.DEFINE_string('nms_threshold', default=nms_threshold, help='label names for the training datasets')
  40. tf.flags.DEFINE_string('net_type', default='resnet18', help='resnet18 or resnet18_nas')
  41. tf.flags.DEFINE_string('nas_sequence', default='64_1-2111-2-1112', help='resnet18 or resnet18_nas')
  42. tf.flags.DEFINE_string('deploy_model_format', default=None, help='the format for the converted model')
  43. tf.flags.DEFINE_string('result_url', default=None, help='result url for training')
  44. model = Interface()
  45. model = neptune.incremental_learning.train(model=model,
  46. train_data=train_data,
  47. epochs=epochs,
  48. batch_size=batch_size,
  49. class_names=class_names,
  50. input_shape=input_shape,
  51. obj_threshold=obj_threshold,
  52. nms_threshold=nms_threshold)
  53. # Save the model based on the config.
  54. # neptune.save_model(model)
  55. if __name__ == '__main__':
  56. main()