You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

interface.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import logging
  2. import numpy as np
  3. import os
  4. import six
  5. import tensorflow as tf
  6. from tqdm import tqdm
  7. from data_gen import DataGen
  8. from neptune.incremental_learning.incremental_learning import IncrementalConfig
  9. from yolo3_multiscale import Yolo3
  10. from yolo3_multiscale import YoloConfig
  11. LOG = logging.getLogger(__name__)  # module-level logger, named after this module
  12. BASE_MODEL_URL = IncrementalConfig().base_model_url  # read once at import time; train() restores from it when the path exists
  13. flags = tf.flags.FLAGS  # shorthand for the TensorFlow flag registry (max_epochs, inference_device are read below)
  14. class Interface:
  15. def __init__(self):
  16. """
  17. initialize logging configuration
  18. """
  19. def train(self, train_data, valid_data):
  20. """
  21. train
  22. """
  23. yolo_config = YoloConfig()
  24. data_gen = DataGen(yolo_config, train_data, valid_data)
  25. config = tf.ConfigProto(allow_soft_placement=True)
  26. config.gpu_options.allow_growth = True
  27. with tf.Session(config=config) as sess:
  28. model = Yolo3(sess, True, yolo_config)
  29. if BASE_MODEL_URL and os.path.exists(BASE_MODEL_URL):
  30. LOG.info(f"loading base model, BASE_MODEL_URL={BASE_MODEL_URL}")
  31. saver = tf.train.Saver()
  32. latest_ckpt = tf.train.latest_checkpoint(BASE_MODEL_URL)
  33. LOG.info(f"latest_ckpt={latest_ckpt}")
  34. saver.restore(sess, latest_ckpt)
  35. steps_per_epoch = int(round(data_gen.train_data_size / data_gen.batch_size))
  36. total = steps_per_epoch * flags.max_epochs
  37. with tqdm(desc='Train: ', total=total) as pbar:
  38. for epoch in range(flags.max_epochs):
  39. LOG.info('Epoch %d...' % epoch)
  40. for step in range(steps_per_epoch): # Get a batch and make a step.
  41. batch_data = data_gen.next_batch_train() # get batch data from Queue
  42. if not batch_data:
  43. continue
  44. batch_loss = model.step(sess, batch_data, True)
  45. # pbar.set_description('Train, loss={:.8f}'.format(batch_loss))
  46. pbar.set_description('Train, input_shape=(%d, %d), loss=%.4f' % (
  47. batch_data['input_shape'][0], batch_data['input_shape'][1], batch_loss))
  48. pbar.update()
  49. # LOG.info('validating...')
  50. # val_loss = self.validate(sess, model, data_gen, flags.batch_size)
  51. # LOG.info('loss of validate data : %.2f' % val_loss)
  52. LOG.info("Saving model, global_step: %d" % model.global_step.eval())
  53. checkpoint_path = os.path.join(model.model_dir, "yolo3-epoch%03d.ckpt" % (epoch))
  54. model.saver.save(sess, checkpoint_path, global_step=model.global_step, write_meta_graph=False)
  55. def validate(self, sess, model, data_gen, batch_size):
  56. """
  57. validate
  58. """
  59. total_loss = 0.0
  60. val_steps = int(round(data_gen.val_data_size / batch_size))
  61. if val_steps <= 0:
  62. return -1.0
  63. for _ in range(val_steps): # Get a batch and make a step.
  64. batch_data = data_gen.next_batch_validate()
  65. if not batch_data:
  66. continue
  67. total_loss += model.step(sess, batch_data, False)
  68. return (total_loss / val_steps)
  69. def avg_checkpoints(self):
  70. """
  71. Average the last N checkpoints in the model_dir.
  72. """
  73. LOG.info("average checkpoints start .......")
  74. config = tf.ConfigProto(allow_soft_placement=True)
  75. config.gpu_options.allow_growth = True
  76. with tf.Session(config=config) as sess:
  77. yolo_config = YoloConfig()
  78. model = Yolo3(sess, False, yolo_config)
  79. model_dir = model.model_dir
  80. num_last_checkpoints = 5
  81. global_step = model.global_step.eval()
  82. global_step_name = model.global_step.name.split(":")[0]
  83. checkpoint_state = tf.train.get_checkpoint_state(model_dir)
  84. if not checkpoint_state:
  85. logging.info("# No checkpoint file found in directory: %s" % model_dir)
  86. return None
  87. # Checkpoints are ordered from oldest to newest.
  88. checkpoints = (checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
  89. if len(checkpoints) < num_last_checkpoints:
  90. logging.info("# Skipping averaging checkpoints because not enough checkpoints is avaliable.")
  91. return None
  92. avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
  93. if not tf.gfile.Exists(avg_model_dir):
  94. logging.info("# Creating new directory %s for saving averaged checkpoints." % avg_model_dir)
  95. tf.gfile.MakeDirs(avg_model_dir)
  96. logging.info("# Reading and averaging variables in checkpoints:")
  97. var_list = tf.contrib.framework.list_variables(checkpoints[0])
  98. var_values, var_dtypes = {}, {}
  99. for (name, shape) in var_list:
  100. if name != global_step_name:
  101. var_values[name] = np.zeros(shape)
  102. for checkpoint in checkpoints:
  103. logging.info(" %s" % checkpoint)
  104. reader = tf.contrib.framework.load_checkpoint(checkpoint)
  105. for name in var_values:
  106. tensor = reader.get_tensor(name)
  107. var_dtypes[name] = tensor.dtype
  108. var_values[name] += tensor
  109. for name in var_values:
  110. var_values[name] /= len(checkpoints)
  111. # Build a graph with same variables in the checkpoints, and save the averaged
  112. # variables into the avg_model_dir.
  113. with tf.Graph().as_default():
  114. tf_vars = [tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name])
  115. for v in var_values]
  116. placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
  117. assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
  118. global_step_var = tf.Variable(global_step, name=global_step_name, trainable=False)
  119. saver = tf.train.Saver(tf.global_variables())
  120. with tf.Session() as sess:
  121. sess.run(tf.global_variables_initializer())
  122. for p, assign_op, (name, value) in zip(placeholders, assign_ops, six.iteritems(var_values)):
  123. sess.run(assign_op, {p: value})
  124. # Use the built saver to save the averaged checkpoint. Only keep 1
  125. # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
  126. saver.save(sess, os.path.join(avg_model_dir, "translate.ckpt"))
  127. logging.info("average checkpoints end .......")
  128. def save_model_pb(self, saved_model_name):
  129. """
  130. save model as a single pb file from checkpoint
  131. """
  132. logging.info("save model as .pb start .......")
  133. config = tf.ConfigProto(allow_soft_placement=True)
  134. config.gpu_options.allow_growth = True
  135. with tf.Session(config=config) as sess:
  136. yolo_config = YoloConfig()
  137. model = Yolo3(sess, False, yolo_config)
  138. input_graph_def = sess.graph.as_graph_def()
  139. if flags.inference_device == '310D':
  140. output_tensors = model.output
  141. else:
  142. output_tensors = [model.boxes, model.scores, model.classes]
  143. print('output_tensors : ', output_tensors)
  144. output_tensors = [t.op.name for t in output_tensors]
  145. graph = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_tensors)
  146. tf.train.write_graph(graph, model.model_dir, saved_model_name, False)
  147. logging.info("save model as .pb end .......")