diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 39915425..54e4aea1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -57,6 +57,12 @@ namespace Tensorflow clear_devices: clear_devices, clear_extraneous_savers: clear_extraneous_savers, strip_default_attrs: strip_default_attrs); + + public string latest_checkpoint(string checkpoint_dir, string latest_filename = null) + => checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename: latest_filename); + + public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) + => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); } } } diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs index 91596d7d..2c9ae7d9 100644 --- a/src/TensorFlowNET.Core/Estimators/Estimator.cs +++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs @@ -18,20 +18,35 @@ namespace Tensorflow.Estimators string _model_dir; + Action _model_fn; + public Estimator(Action model_fn, RunConfig config) { _config = config; _model_dir = _config.model_dir; _session_config = _config.session_config; + _model_fn = model_fn; } - public Estimator train(Action input_fn, int max_steps = 1, + public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null, _NewCheckpointListenerForEvaluate[] saving_listeners = null) { + if(max_steps > 0) + { + var start_step = _load_global_step_from_checkpoint_dir(_model_dir); + } + _train_model(); throw new NotImplementedException(""); } + private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) + { + var cp = tf.train.latest_checkpoint(checkpoint_dir); + + return 0; + } + private void _train_model() { _train_model_default(); diff --git a/src/TensorFlowNET.Core/Estimators/EvalSpec.cs b/src/TensorFlowNET.Core/Estimators/EvalSpec.cs index ca9ff94f..c5a5820e 100644 --- a/src/TensorFlowNET.Core/Estimators/EvalSpec.cs +++ b/src/TensorFlowNET.Core/Estimators/EvalSpec.cs @@ -6,9 +6,11 @@ namespace Tensorflow.Estimators { public class EvalSpec { + string _name; + public EvalSpec(string name, Action input_fn, FinalExporter exporters) { - + _name = name; } } } diff --git a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs index 2252b53c..64b3a829 100644 --- a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs +++ b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs @@ -6,11 +6,16 @@ namespace Tensorflow.Estimators { public class TrainSpec { - public int max_steps { get; set; } + int _max_steps; + public int max_steps => _max_steps; + + Action _input_fn; + public Action input_fn => _input_fn; public TrainSpec(Action input_fn, int max_steps) { - this.max_steps = max_steps; + _max_steps = max_steps; + _input_fn = input_fn; } } } diff --git a/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs b/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs index 2850255f..71464562 100644 --- a/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs +++ b/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs @@ -6,5 +6,11 @@ namespace Tensorflow.Estimators { public class _NewCheckpointListenerForEvaluate { + _Evaluator _evaluator; + + public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs) + { + _evaluator = evaluator; + } } } diff --git a/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs b/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs index 88cb7638..e7ad6905 100644 --- a/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs +++ b/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs @@ -32,15 +32,17 @@ namespace Tensorflow.Estimators /// private void run_local() { + var train_hooks = new Action[0]; Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + "after every checkpoint. Checkpoint frequency is determined " + $"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + $"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); - /*_estimator.train(input_fn: _train_spec.input_fn, + var saving_listeners = new _NewCheckpointListenerForEvaluate[0]; + _estimator.train(input_fn: _train_spec.input_fn, max_steps: _train_spec.max_steps, hooks: train_hooks, - saving_listeners: saving_listeners);*/ + saving_listeners: saving_listeners); } } } diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs index 4ff538db..c23368c1 100644 --- a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using static Tensorflow.SaverDef.Types; +using static Tensorflow.Binding; namespace Tensorflow { @@ -144,5 +145,54 @@ namespace Tensorflow return prefix + ".index"; return prefix; } + + /// + /// Finds the filename of latest saved checkpoint file. + /// + /// + /// + /// + public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null) + { + // Pick the latest checkpoint based on checkpoint state. + var ckpt = get_checkpoint_state(checkpoint_dir, latest_filename); + if(ckpt != null && !string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) + { + // Look for either a V2 path or a V1 path, with priority for V2. + var v2_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V2); + var v1_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V1); + if (File.Exists(v2_path) || File.Exists(v1_path)) + return ckpt.ModelCheckpointPath; + else + throw new ValueError($"Couldn't match files for checkpoint {ckpt.ModelCheckpointPath}"); + } + return null; + } + + public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) + { + var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); + if (File.Exists(coord_checkpoint_filename)) + { + var file_content = File.ReadAllBytes(coord_checkpoint_filename); + var ckpt = CheckpointState.Parser.ParseFrom(file_content); + if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) + throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); + // For relative model_checkpoint_path and all_model_checkpoint_paths, + // prepend checkpoint_dir. + if (!Path.IsPathRooted(ckpt.ModelCheckpointPath)) + ckpt.ModelCheckpointPath = Path.Combine(checkpoint_dir, ckpt.ModelCheckpointPath); + foreach(var i in range(len(ckpt.AllModelCheckpointPaths))) + { + var p = ckpt.AllModelCheckpointPaths[i]; + if (!Path.IsPathRooted(p)) + ckpt.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, p); + } + + return ckpt; + } + + return null; + } } } diff --git a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs index 33c5c236..6759e03e 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs @@ -19,19 +19,28 @@ namespace Tensorflow.Models.ObjectDetection int sample_1_of_n_eval_on_train_examples = 1) { var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); + + // Create the input functions for TRAIN/EVAL/PREDICT. + Action train_input_fn = () => { }; + var eval_input_configs = config.EvalInputReader; var eval_input_fns = new Action[eval_input_configs.Count]; var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray(); + Action eval_on_train_input_fn = () => { }; + Action predict_input_fn = () => { }; Action model_fn = () => { }; var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config); return new TrainAndEvalDict { estimator = estimator, - train_steps = train_steps, + train_input_fn = train_input_fn, eval_input_fns = eval_input_fns, - eval_input_names = eval_input_names + eval_input_names = eval_input_names, + eval_on_train_input_fn = eval_on_train_input_fn, + predict_input_fn = predict_input_fn, + train_steps = train_steps }; } @@ -46,10 +55,7 @@ namespace Tensorflow.Models.ObjectDetection .Select(x => x.ToString()) .ToArray(); - var eval_specs = new List() - { - new EvalSpec("", null, null) // for test. - }; + var eval_specs = new List(); foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList())) { var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}"; diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs index e879ec80..a31d21be 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs @@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model"; string pipeline_config_path = "ObjectDetection/Models/faster_rcnn_resnet101_voc07.config"; - int num_train_steps = 1; + int num_train_steps = 50; int sample_1_of_n_eval_examples = 1; int sample_1_of_n_eval_on_train_examples = 5;