diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs
index 39915425..54e4aea1 100644
--- a/src/TensorFlowNET.Core/APIs/tf.train.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.train.cs
@@ -57,6 +57,12 @@ namespace Tensorflow
clear_devices: clear_devices,
clear_extraneous_savers: clear_extraneous_savers,
strip_default_attrs: strip_default_attrs);
+
+ public string latest_checkpoint(string checkpoint_dir, string latest_filename = null)
+ => checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename: latest_filename);
+
+ public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
+ => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename);
}
}
}
diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs
index 91596d7d..2c9ae7d9 100644
--- a/src/TensorFlowNET.Core/Estimators/Estimator.cs
+++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs
@@ -18,20 +18,35 @@ namespace Tensorflow.Estimators
string _model_dir;
+ Action _model_fn;
+
public Estimator(Action model_fn, RunConfig config)
{
_config = config;
_model_dir = _config.model_dir;
_session_config = _config.session_config;
+ _model_fn = model_fn;
}
- public Estimator train(Action input_fn, int max_steps = 1,
+ public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null,
_NewCheckpointListenerForEvaluate[] saving_listeners = null)
{
+ if(max_steps > 0)
+ {
+ var start_step = _load_global_step_from_checkpoint_dir(_model_dir);
+ }
+
_train_model();
throw new NotImplementedException("");
}
+ private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
+ {
+ var cp = tf.train.latest_checkpoint(checkpoint_dir);
+
+ return 0;
+ }
+
private void _train_model()
{
_train_model_default();
diff --git a/src/TensorFlowNET.Core/Estimators/EvalSpec.cs b/src/TensorFlowNET.Core/Estimators/EvalSpec.cs
index ca9ff94f..c5a5820e 100644
--- a/src/TensorFlowNET.Core/Estimators/EvalSpec.cs
+++ b/src/TensorFlowNET.Core/Estimators/EvalSpec.cs
@@ -6,9 +6,11 @@ namespace Tensorflow.Estimators
{
public class EvalSpec
{
+ string _name;
+
public EvalSpec(string name, Action input_fn, FinalExporter exporters)
{
-
+ _name = name;
}
}
}
diff --git a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs
index 2252b53c..64b3a829 100644
--- a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs
+++ b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs
@@ -6,11 +6,16 @@ namespace Tensorflow.Estimators
{
public class TrainSpec
{
- public int max_steps { get; set; }
+ int _max_steps;
+ public int max_steps => _max_steps;
+
+ Action _input_fn;
+ public Action input_fn => _input_fn;
public TrainSpec(Action input_fn, int max_steps)
{
- this.max_steps = max_steps;
+ _max_steps = max_steps;
+ _input_fn = input_fn;
}
}
}
diff --git a/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs b/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs
index 2850255f..71464562 100644
--- a/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs
+++ b/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs
@@ -6,5 +6,11 @@ namespace Tensorflow.Estimators
{
public class _NewCheckpointListenerForEvaluate
{
+ _Evaluator _evaluator;
+
+ public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs)
+ {
+ _evaluator = evaluator;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs b/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
index 88cb7638..e7ad6905 100644
--- a/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
+++ b/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
@@ -32,15 +32,17 @@ namespace Tensorflow.Estimators
///
private void run_local()
{
+ var train_hooks = new Action[0];
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " +
"after every checkpoint. Checkpoint frequency is determined " +
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " +
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}.");
var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps);
- /*_estimator.train(input_fn: _train_spec.input_fn,
+ var saving_listeners = new _NewCheckpointListenerForEvaluate[0];
+ _estimator.train(input_fn: _train_spec.input_fn,
max_steps: _train_spec.max_steps,
hooks: train_hooks,
- saving_listeners: saving_listeners);*/
+ saving_listeners: saving_listeners);
}
}
}
diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
index 4ff538db..c23368c1 100644
--- a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using static Tensorflow.SaverDef.Types;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -144,5 +145,54 @@ namespace Tensorflow
return prefix + ".index";
return prefix;
}
+
+ ///
+ /// Finds the filename of latest saved checkpoint file.
+ ///
+ ///
+ ///
+ ///
+ public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null)
+ {
+ // Pick the latest checkpoint based on checkpoint state.
+ var ckpt = get_checkpoint_state(checkpoint_dir, latest_filename);
+ if(ckpt != null && !string.IsNullOrEmpty(ckpt.ModelCheckpointPath))
+ {
+ // Look for either a V2 path or a V1 path, with priority for V2.
+ var v2_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V2);
+ var v1_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V1);
+ if (File.Exists(v2_path) || File.Exists(v1_path))
+ return ckpt.ModelCheckpointPath;
+ else
+ throw new ValueError($"Couldn't match files for checkpoint {ckpt.ModelCheckpointPath}");
+ }
+ return null;
+ }
+
+ public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
+ {
+ var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename);
+ if (File.Exists(coord_checkpoint_filename))
+ {
+ var file_content = File.ReadAllBytes(coord_checkpoint_filename);
+ var ckpt = CheckpointState.Parser.ParseFrom(file_content);
+ if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath))
+ throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}");
+ // For relative model_checkpoint_path and all_model_checkpoint_paths,
+ // prepend checkpoint_dir.
+ if (!Path.IsPathRooted(ckpt.ModelCheckpointPath))
+ ckpt.ModelCheckpointPath = Path.Combine(checkpoint_dir, ckpt.ModelCheckpointPath);
+ foreach(var i in range(len(ckpt.AllModelCheckpointPaths)))
+ {
+ var p = ckpt.AllModelCheckpointPaths[i];
+ if (!Path.IsPathRooted(p))
+ ckpt.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, p);
+ }
+
+ return ckpt;
+ }
+
+ return null;
+ }
}
}
diff --git a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs
index 33c5c236..6759e03e 100644
--- a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs
+++ b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs
@@ -19,19 +19,28 @@ namespace Tensorflow.Models.ObjectDetection
int sample_1_of_n_eval_on_train_examples = 1)
{
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);
+
+ // Create the input functions for TRAIN/EVAL/PREDICT.
+ Action train_input_fn = () => { };
+
var eval_input_configs = config.EvalInputReader;
var eval_input_fns = new Action[eval_input_configs.Count];
var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray();
+ Action eval_on_train_input_fn = () => { };
+ Action predict_input_fn = () => { };
Action model_fn = () => { };
var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config);
return new TrainAndEvalDict
{
estimator = estimator,
- train_steps = train_steps,
+ train_input_fn = train_input_fn,
eval_input_fns = eval_input_fns,
- eval_input_names = eval_input_names
+ eval_input_names = eval_input_names,
+ eval_on_train_input_fn = eval_on_train_input_fn,
+ predict_input_fn = predict_input_fn,
+ train_steps = train_steps
};
}
@@ -46,10 +55,7 @@ namespace Tensorflow.Models.ObjectDetection
.Select(x => x.ToString())
.ToArray();
- var eval_specs = new List()
- {
- new EvalSpec("", null, null) // for test.
- };
+ var eval_specs = new List();
foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList()))
{
var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}";
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs
index e879ec80..a31d21be 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs
@@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection
string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model";
string pipeline_config_path = "ObjectDetection/Models/faster_rcnn_resnet101_voc07.config";
- int num_train_steps = 1;
+ int num_train_steps = 50;
int sample_1_of_n_eval_examples = 1;
int sample_1_of_n_eval_on_train_examples = 5;