You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

interface.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import numpy as np
  16. import os
  17. import six
  18. import tensorflow as tf
  19. from tqdm import tqdm
  20. from data_gen import DataGen
  21. from sedna.incremental_learning.incremental_learning import IncrementalConfig
  22. from yolo3_multiscale import Yolo3
  23. from yolo3_multiscale import YoloConfig
  24. LOG = logging.getLogger(__name__)
  25. BASE_MODEL_URL = IncrementalConfig().base_model_url
  26. flags = tf.flags.FLAGS
  27. class Interface:
  28. def __init__(self):
  29. """
  30. initialize logging configuration
  31. """
  32. def train(self, train_data, valid_data):
  33. """
  34. train
  35. """
  36. yolo_config = YoloConfig()
  37. data_gen = DataGen(yolo_config, train_data, valid_data)
  38. config = tf.ConfigProto(allow_soft_placement=True)
  39. config.gpu_options.allow_growth = True
  40. with tf.Session(config=config) as sess:
  41. model = Yolo3(sess, True, yolo_config)
  42. if BASE_MODEL_URL and os.path.exists(BASE_MODEL_URL):
  43. LOG.info(f"loading base model, BASE_MODEL_URL={BASE_MODEL_URL}")
  44. saver = tf.train.Saver()
  45. latest_ckpt = tf.train.latest_checkpoint(BASE_MODEL_URL)
  46. LOG.info(f"latest_ckpt={latest_ckpt}")
  47. saver.restore(sess, latest_ckpt)
  48. steps_per_epoch = int(round(data_gen.train_data_size / data_gen.batch_size))
  49. total = steps_per_epoch * flags.max_epochs
  50. with tqdm(desc='Train: ', total=total) as pbar:
  51. for epoch in range(flags.max_epochs):
  52. LOG.info('Epoch %d...' % epoch)
  53. for step in range(steps_per_epoch): # Get a batch and make a step.
  54. batch_data = data_gen.next_batch_train() # get batch data from Queue
  55. if not batch_data:
  56. continue
  57. batch_loss = model.step(sess, batch_data, True)
  58. # pbar.set_description('Train, loss={:.8f}'.format(batch_loss))
  59. pbar.set_description('Train, input_shape=(%d, %d), loss=%.4f' % (
  60. batch_data['input_shape'][0], batch_data['input_shape'][1], batch_loss))
  61. pbar.update()
  62. # LOG.info('validating...')
  63. # val_loss = self.validate(sess, model, data_gen, flags.batch_size)
  64. # LOG.info('loss of validate data : %.2f' % val_loss)
  65. LOG.info("Saving model, global_step: %d" % model.global_step.eval())
  66. checkpoint_path = os.path.join(model.model_dir, "yolo3-epoch%03d.ckpt" % (epoch))
  67. model.saver.save(sess, checkpoint_path, global_step=model.global_step, write_meta_graph=False)
  68. def validate(self, sess, model, data_gen, batch_size):
  69. """
  70. validate
  71. """
  72. total_loss = 0.0
  73. val_steps = int(round(data_gen.val_data_size / batch_size))
  74. if val_steps <= 0:
  75. return -1.0
  76. for _ in range(val_steps): # Get a batch and make a step.
  77. batch_data = data_gen.next_batch_validate()
  78. if not batch_data:
  79. continue
  80. total_loss += model.step(sess, batch_data, False)
  81. return (total_loss / val_steps)
  82. def avg_checkpoints(self):
  83. """
  84. Average the last N checkpoints in the model_dir.
  85. """
  86. LOG.info("average checkpoints start .......")
  87. config = tf.ConfigProto(allow_soft_placement=True)
  88. config.gpu_options.allow_growth = True
  89. with tf.Session(config=config) as sess:
  90. yolo_config = YoloConfig()
  91. model = Yolo3(sess, False, yolo_config)
  92. model_dir = model.model_dir
  93. num_last_checkpoints = 5
  94. global_step = model.global_step.eval()
  95. global_step_name = model.global_step.name.split(":")[0]
  96. checkpoint_state = tf.train.get_checkpoint_state(model_dir)
  97. if not checkpoint_state:
  98. logging.info("# No checkpoint file found in directory: %s" % model_dir)
  99. return None
  100. # Checkpoints are ordered from oldest to newest.
  101. checkpoints = (checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])
  102. if len(checkpoints) < num_last_checkpoints:
  103. logging.info("# Skipping averaging checkpoints because not enough checkpoints is avaliable.")
  104. return None
  105. avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
  106. if not tf.gfile.Exists(avg_model_dir):
  107. logging.info("# Creating new directory %s for saving averaged checkpoints." % avg_model_dir)
  108. tf.gfile.MakeDirs(avg_model_dir)
  109. logging.info("# Reading and averaging variables in checkpoints:")
  110. var_list = tf.contrib.framework.list_variables(checkpoints[0])
  111. var_values, var_dtypes = {}, {}
  112. for (name, shape) in var_list:
  113. if name != global_step_name:
  114. var_values[name] = np.zeros(shape)
  115. for checkpoint in checkpoints:
  116. logging.info(" %s" % checkpoint)
  117. reader = tf.contrib.framework.load_checkpoint(checkpoint)
  118. for name in var_values:
  119. tensor = reader.get_tensor(name)
  120. var_dtypes[name] = tensor.dtype
  121. var_values[name] += tensor
  122. for name in var_values:
  123. var_values[name] /= len(checkpoints)
  124. # Build a graph with same variables in the checkpoints, and save the averaged
  125. # variables into the avg_model_dir.
  126. with tf.Graph().as_default():
  127. tf_vars = [tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name])
  128. for v in var_values]
  129. placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
  130. assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
  131. global_step_var = tf.Variable(global_step, name=global_step_name, trainable=False)
  132. saver = tf.train.Saver(tf.global_variables())
  133. with tf.Session() as sess:
  134. sess.run(tf.global_variables_initializer())
  135. for p, assign_op, (name, value) in zip(placeholders, assign_ops, six.iteritems(var_values)):
  136. sess.run(assign_op, {p: value})
  137. # Use the built saver to save the averaged checkpoint. Only keep 1
  138. # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
  139. saver.save(sess, os.path.join(avg_model_dir, "translate.ckpt"))
  140. logging.info("average checkpoints end .......")
  141. def save_model_pb(self, saved_model_name):
  142. """
  143. save model as a single pb file from checkpoint
  144. """
  145. logging.info("save model as .pb start .......")
  146. config = tf.ConfigProto(allow_soft_placement=True)
  147. config.gpu_options.allow_growth = True
  148. with tf.Session(config=config) as sess:
  149. yolo_config = YoloConfig()
  150. model = Yolo3(sess, False, yolo_config)
  151. input_graph_def = sess.graph.as_graph_def()
  152. if flags.inference_device == '310D':
  153. output_tensors = model.output
  154. else:
  155. output_tensors = [model.boxes, model.scores, model.classes]
  156. print('output_tensors : ', output_tensors)
  157. output_tensors = [t.op.name for t in output_tensors]
  158. graph = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_tensors)
  159. tf.train.write_graph(graph, model.model_dir, saved_model_name, False)
  160. logging.info("save model as .pb end .......")