diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index c1e76d11..3a790327 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -43,7 +43,8 @@ namespace Tensorflow public ExponentialMovingAverage ExponentialMovingAverage(float decay) => new ExponentialMovingAverage(decay); - public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); + public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5) + => new Saver(var_list: var_list, max_to_keep: max_to_keep); public string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 9029fb8f..15ad511b 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -108,10 +108,7 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); - if(tf.get_default_graph()._nodes_by_name.Count >= 23868) - { - } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); //if (loop_state != null) //loop_state.EnterGradWhileContext(op, before: true); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index 73f5c213..8cd4a252 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -54,6 +54,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO List first_stage_trainable_var_list; Operation train_op_with_frozen_variables; Operation train_op_with_all_variables; + Saver loader; + Saver saver; #endregion public bool Run() @@ -74,7 +76,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO public void Train(Session sess) { - + sess.run(tf.global_variables_initializer()); + print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... "); + loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT); } public void Test(Session sess) @@ -184,6 +188,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO }); }); + tf_with(tf.name_scope("loader_and_saver"), delegate + { + loader = tf.train.Saver(net_var); + saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10); + }); + + tf_with(tf.name_scope("summary"), delegate + { + tf.summary.scalar("learn_rate", learn_rate); + tf.summary.scalar("giou_loss", giou_loss); + tf.summary.scalar("conf_loss", conf_loss); + tf.summary.scalar("prob_loss", prob_loss); + tf.summary.scalar("total_loss", loss); + }); + return graph; } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs index b5c46151..39308da8 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs @@ -60,7 +60,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO public TrainConfig(string root) { _root = root; - INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); + INITIAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); } }