@@ -57,6 +57,12 @@ namespace Tensorflow | |||
clear_devices: clear_devices, | |||
clear_extraneous_savers: clear_extraneous_savers, | |||
strip_default_attrs: strip_default_attrs); | |||
public string latest_checkpoint(string checkpoint_dir, string latest_filename = null) | |||
=> checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename: latest_filename); | |||
public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | |||
=> checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); | |||
} | |||
} | |||
} |
@@ -18,20 +18,35 @@ namespace Tensorflow.Estimators | |||
string _model_dir; | |||
Action _model_fn; | |||
public Estimator(Action model_fn, RunConfig config) | |||
{ | |||
_config = config; | |||
_model_dir = _config.model_dir; | |||
_session_config = _config.session_config; | |||
_model_fn = model_fn; | |||
} | |||
public Estimator train(Action input_fn, int max_steps = 1, | |||
public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null, | |||
_NewCheckpointListenerForEvaluate[] saving_listeners = null) | |||
{ | |||
if(max_steps > 0) | |||
{ | |||
var start_step = _load_global_step_from_checkpoint_dir(_model_dir); | |||
} | |||
_train_model(); | |||
throw new NotImplementedException(""); | |||
} | |||
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||
{ | |||
var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
return 0; | |||
} | |||
private void _train_model() | |||
{ | |||
_train_model_default(); | |||
@@ -6,9 +6,11 @@ namespace Tensorflow.Estimators | |||
{ | |||
public class EvalSpec | |||
{ | |||
string _name; | |||
public EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||
{ | |||
_name = name; | |||
} | |||
} | |||
} |
@@ -6,11 +6,16 @@ namespace Tensorflow.Estimators | |||
{ | |||
public class TrainSpec | |||
{ | |||
public int max_steps { get; set; } | |||
int _max_steps; | |||
public int max_steps => _max_steps; | |||
Action _input_fn; | |||
public Action input_fn => _input_fn; | |||
public TrainSpec(Action input_fn, int max_steps) | |||
{ | |||
this.max_steps = max_steps; | |||
_max_steps = max_steps; | |||
_input_fn = input_fn; | |||
} | |||
} | |||
} |
@@ -6,5 +6,11 @@ namespace Tensorflow.Estimators | |||
{ | |||
public class _NewCheckpointListenerForEvaluate | |||
{ | |||
_Evaluator _evaluator; | |||
public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs) | |||
{ | |||
_evaluator = evaluator; | |||
} | |||
} | |||
} |
@@ -32,15 +32,17 @@ namespace Tensorflow.Estimators | |||
/// </summary> | |||
private void run_local() | |||
{ | |||
var train_hooks = new Action[0]; | |||
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + | |||
"after every checkpoint. Checkpoint frequency is determined " + | |||
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + | |||
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); | |||
var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); | |||
/*_estimator.train(input_fn: _train_spec.input_fn, | |||
var saving_listeners = new _NewCheckpointListenerForEvaluate[0]; | |||
_estimator.train(input_fn: _train_spec.input_fn, | |||
max_steps: _train_spec.max_steps, | |||
hooks: train_hooks, | |||
saving_listeners: saving_listeners);*/ | |||
saving_listeners: saving_listeners); | |||
} | |||
} | |||
} |
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using static Tensorflow.SaverDef.Types; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -144,5 +145,54 @@ namespace Tensorflow | |||
return prefix + ".index"; | |||
return prefix; | |||
} | |||
/// <summary> | |||
/// Finds the filename of latest saved checkpoint file. | |||
/// </summary> | |||
/// <param name="checkpoint_dir"></param> | |||
/// <param name="latest_filename"></param> | |||
/// <returns></returns> | |||
public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null) | |||
{ | |||
// Pick the latest checkpoint based on checkpoint state. | |||
var ckpt = get_checkpoint_state(checkpoint_dir, latest_filename); | |||
if(ckpt != null && !string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) | |||
{ | |||
// Look for either a V2 path or a V1 path, with priority for V2. | |||
var v2_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V2); | |||
var v1_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V1); | |||
if (File.Exists(v2_path) || File.Exists(v1_path)) | |||
return ckpt.ModelCheckpointPath; | |||
else | |||
throw new ValueError($"Couldn't match files for checkpoint {ckpt.ModelCheckpointPath}"); | |||
} | |||
return null; | |||
} | |||
public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | |||
{ | |||
var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); | |||
if (File.Exists(coord_checkpoint_filename)) | |||
{ | |||
var file_content = File.ReadAllBytes(coord_checkpoint_filename); | |||
var ckpt = CheckpointState.Parser.ParseFrom(file_content); | |||
if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) | |||
throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); | |||
// For relative model_checkpoint_path and all_model_checkpoint_paths, | |||
// prepend checkpoint_dir. | |||
if (!Path.IsPathRooted(ckpt.ModelCheckpointPath)) | |||
ckpt.ModelCheckpointPath = Path.Combine(checkpoint_dir, ckpt.ModelCheckpointPath); | |||
foreach(var i in range(len(ckpt.AllModelCheckpointPaths))) | |||
{ | |||
var p = ckpt.AllModelCheckpointPaths[i]; | |||
if (!Path.IsPathRooted(p)) | |||
ckpt.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, p); | |||
} | |||
return ckpt; | |||
} | |||
return null; | |||
} | |||
} | |||
} |
@@ -19,19 +19,28 @@ namespace Tensorflow.Models.ObjectDetection | |||
int sample_1_of_n_eval_on_train_examples = 1) | |||
{ | |||
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); | |||
// Create the input functions for TRAIN/EVAL/PREDICT. | |||
Action train_input_fn = () => { }; | |||
var eval_input_configs = config.EvalInputReader; | |||
var eval_input_fns = new Action[eval_input_configs.Count]; | |||
var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray(); | |||
Action eval_on_train_input_fn = () => { }; | |||
Action predict_input_fn = () => { }; | |||
Action model_fn = () => { }; | |||
var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config); | |||
return new TrainAndEvalDict | |||
{ | |||
estimator = estimator, | |||
train_steps = train_steps, | |||
train_input_fn = train_input_fn, | |||
eval_input_fns = eval_input_fns, | |||
eval_input_names = eval_input_names | |||
eval_input_names = eval_input_names, | |||
eval_on_train_input_fn = eval_on_train_input_fn, | |||
predict_input_fn = predict_input_fn, | |||
train_steps = train_steps | |||
}; | |||
} | |||
@@ -46,10 +55,7 @@ namespace Tensorflow.Models.ObjectDetection | |||
.Select(x => x.ToString()) | |||
.ToArray(); | |||
var eval_specs = new List<EvalSpec>() | |||
{ | |||
new EvalSpec("", null, null) // for test. | |||
}; | |||
var eval_specs = new List<EvalSpec>(); | |||
foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList())) | |||
{ | |||
var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}"; | |||
@@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection | |||
string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model"; | |||
string pipeline_config_path = "ObjectDetection/Models/faster_rcnn_resnet101_voc07.config"; | |||
int num_train_steps = 1; | |||
int num_train_steps = 50; | |||
int sample_1_of_n_eval_examples = 1; | |||
int sample_1_of_n_eval_on_train_examples = 5; | |||