Browse Source

Start to train Yolo #359

tags/v0.12
Oceania2018 6 years ago
parent
commit
70df2bbbb7
4 changed files with 23 additions and 6 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +0
    -3
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  3. +20
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
  4. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs

+ 2
- 1
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -43,7 +43,8 @@ namespace Tensorflow
public ExponentialMovingAverage ExponentialMovingAverage(float decay)
=> new ExponentialMovingAverage(decay);

public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list);
public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5)
=> new Saver(var_list: var_list, max_to_keep: max_to_keep);

public string write_graph(Graph graph, string logdir, string name, bool as_text = true)
=> graph_io.write_graph(graph, logdir, name, as_text);


+ 0
- 3
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -108,10 +108,7 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
if(tf.get_default_graph()._nodes_by_name.Count >= 23868)
{

}
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);


+ 20
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs View File

@@ -54,6 +54,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
List<RefVariable> first_stage_trainable_var_list;
Operation train_op_with_frozen_variables;
Operation train_op_with_all_variables;
Saver loader;
Saver saver;
#endregion

public bool Run()
@@ -74,7 +76,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

public void Train(Session sess)
{

sess.run(tf.global_variables_initializer());
print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... ");
loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT);
}

public void Test(Session sess)
@@ -184,6 +188,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
});
});

tf_with(tf.name_scope("loader_and_saver"), delegate
{
loader = tf.train.Saver(net_var);
saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);
});

tf_with(tf.name_scope("summary"), delegate
{
tf.summary.scalar("learn_rate", learn_rate);
tf.summary.scalar("giou_loss", giou_loss);
tf.summary.scalar("conf_loss", conf_loss);
tf.summary.scalar("prob_loss", prob_loss);
tf.summary.scalar("total_loss", loss);
});

return graph;
}



+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs View File

@@ -60,7 +60,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
public TrainConfig(string root)
{
_root = root;
INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt");
INITIAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt");
ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt");
}
}


Loading…
Cancel
Save