Browse Source

latest_checkpoint

tags/v0.12
Oceania2018 6 years ago
parent
commit
bffb9cdeb4
9 changed files with 105 additions and 13 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +16
    -1
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  3. +3
    -1
      src/TensorFlowNET.Core/Estimators/EvalSpec.cs
  4. +7
    -2
      src/TensorFlowNET.Core/Estimators/TrainSpec.cs
  5. +6
    -0
      src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs
  6. +4
    -2
      src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
  7. +50
    -0
      src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
  8. +12
    -6
      src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs
  9. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs

+ 6
- 0
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -57,6 +57,12 @@ namespace Tensorflow
clear_devices: clear_devices,
clear_extraneous_savers: clear_extraneous_savers,
strip_default_attrs: strip_default_attrs);

/// <summary>
/// Forwards to <c>checkpoint_management.latest_checkpoint</c> and returns its result.
/// </summary>
/// <param name="checkpoint_dir">Directory handed to the checkpoint lookup.</param>
/// <param name="latest_filename">Optional name forwarded unchanged; <c>null</c> by default.</param>
/// <returns>The string produced by the underlying lookup.</returns>
public string latest_checkpoint(string checkpoint_dir, string latest_filename = null)
{
    return checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename);
}

/// <summary>
/// Forwards to <c>checkpoint_management.get_checkpoint_state</c> and returns its result.
/// </summary>
/// <param name="checkpoint_dir">Directory handed to the checkpoint state loader.</param>
/// <param name="latest_filename">Optional name forwarded unchanged; <c>null</c> by default.</param>
/// <returns>The <c>CheckpointState</c> produced by the underlying loader.</returns>
public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
{
    return checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename);
}
}
}
}

+ 16
- 1
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -18,20 +18,35 @@ namespace Tensorflow.Estimators

string _model_dir;

Action _model_fn;

/// <summary>
/// Builds an estimator from a model function and a run configuration.
/// </summary>
/// <param name="model_fn">Model function; stored as-is.</param>
/// <param name="config">Run configuration; its <c>model_dir</c> and <c>session_config</c> are copied into fields.</param>
public Estimator(Action model_fn, RunConfig config)
{
    this._model_fn = model_fn;
    this._config = config;

    // Cache the values the estimator reads most often from the configuration.
    this._model_dir = config.model_dir;
    this._session_config = config.session_config;
}

/// <summary>
/// Entry point for training. Work in progress: after invoking <c>_train_model</c>
/// this method always throws <see cref="NotImplementedException"/>.
/// </summary>
/// <param name="input_fn">Input function supplied by the caller (not used by this body yet).</param>
/// <param name="max_steps">When greater than zero, the global step is looked up from the model directory.</param>
/// <param name="hooks">Optional training hooks (not used by this body yet).</param>
/// <param name="saving_listeners">Optional checkpoint listeners (not used by this body yet).</param>
/// <returns>Never returns normally in the current implementation.</returns>
/// <exception cref="NotImplementedException">Always thrown once <c>_train_model</c> has returned.</exception>
public Estimator train(Action input_fn, int max_steps = 1, Action[] hooks = null,
    _NewCheckpointListenerForEvaluate[] saving_listeners = null)
{
    if (max_steps > 0)
    {
        // The resolved step is not consumed yet; the call is kept so the
        // checkpoint lookup (and any error it raises) still happens.
        var start_step = _load_global_step_from_checkpoint_dir(_model_dir);
    }

    _train_model();
    throw new NotImplementedException("");
}

/// <summary>
/// Looks up the latest checkpoint in <paramref name="checkpoint_dir"/> and reports the global step.
/// Placeholder: the checkpoint path is fetched but not read, and the result is always 0.
/// </summary>
/// <param name="checkpoint_dir">Directory passed to <c>tf.train.latest_checkpoint</c>.</param>
/// <returns>Always 0.</returns>
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
{
    // Path is resolved but not inspected yet.
    string latest_path = tf.train.latest_checkpoint(checkpoint_dir);

    const int global_step = 0;
    return global_step;
}

private void _train_model()
{
_train_model_default();


+ 3
- 1
src/TensorFlowNET.Core/Estimators/EvalSpec.cs View File

@@ -6,9 +6,11 @@ namespace Tensorflow.Estimators
{
/// <summary>
/// Evaluation settings for an estimator run. Only the name is retained at present.
/// </summary>
public class EvalSpec
{
    // Name given to this evaluation spec at construction time.
    private string _name;

    /// <summary>
    /// Creates an evaluation spec.
    /// </summary>
    /// <param name="name">Name stored on the instance.</param>
    /// <param name="input_fn">Accepted but not stored by this constructor.</param>
    /// <param name="exporters">Accepted but not stored by this constructor.</param>
    public EvalSpec(string name, Action input_fn, FinalExporter exporters)
    {
        this._name = name;
    }
}
}

+ 7
- 2
src/TensorFlowNET.Core/Estimators/TrainSpec.cs View File

@@ -6,11 +6,16 @@ namespace Tensorflow.Estimators
{
/// <summary>
/// Training settings for an estimator run: the input function and the step limit.
/// Both values are fixed at construction and exposed through read-only properties.
/// </summary>
public class TrainSpec
{
    int _max_steps;

    /// <summary>Step limit supplied to the constructor.</summary>
    public int max_steps => _max_steps;

    Action _input_fn;

    /// <summary>Input function supplied to the constructor.</summary>
    public Action input_fn => _input_fn;

    /// <summary>
    /// Creates a training spec.
    /// </summary>
    /// <param name="input_fn">Input function to expose via <see cref="input_fn"/>.</param>
    /// <param name="max_steps">Step limit to expose via <see cref="max_steps"/>.</param>
    public TrainSpec(Action input_fn, int max_steps)
    {
        _max_steps = max_steps;
        _input_fn = input_fn;
    }
}
}

+ 6
- 0
src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs View File

@@ -6,5 +6,11 @@ namespace Tensorflow.Estimators
{
/// <summary>
/// Checkpoint listener that holds a reference to an evaluator.
/// </summary>
public class _NewCheckpointListenerForEvaluate
{
    // Evaluator captured at construction time.
    private _Evaluator _evaluator;

    /// <summary>
    /// Creates the listener.
    /// </summary>
    /// <param name="evaluator">Evaluator stored on the instance.</param>
    /// <param name="eval_throttle_secs">Accepted but not stored by this constructor.</param>
    public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs)
    {
        this._evaluator = evaluator;
    }
}
}

+ 4
- 2
src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs View File

@@ -32,15 +32,17 @@ namespace Tensorflow.Estimators
/// </summary>
private void run_local()
{
    // No extra training hooks are registered for the local run.
    var train_hooks = new Action[0];
    Console.WriteLine("Start train and evaluate loop. The evaluate will happen " +
        "after every checkpoint. Checkpoint frequency is determined " +
        $"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " +
        $"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}.");
    var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps);

    // Listener list is empty for now: the evaluator above is constructed but
    // not yet attached as a checkpoint listener.
    var saving_listeners = new _NewCheckpointListenerForEvaluate[0];
    _estimator.train(input_fn: _train_spec.input_fn,
        max_steps: _train_spec.max_steps,
        hooks: train_hooks,
        saving_listeners: saving_listeners);
}
}
}

+ 50
- 0
src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using static Tensorflow.SaverDef.Types;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -144,5 +145,54 @@ namespace Tensorflow
return prefix + ".index";
return prefix;
}

/// <summary>
/// Finds the filename of the latest saved checkpoint file.
/// </summary>
/// <param name="checkpoint_dir">Directory whose checkpoint state is loaded.</param>
/// <param name="latest_filename">Optional name forwarded to <c>get_checkpoint_state</c>; <c>null</c> by default.</param>
/// <returns>
/// The model checkpoint path recorded in the checkpoint state, or <c>null</c> when no
/// state is available or the recorded path is empty.
/// </returns>
/// <exception cref="ValueError">
/// The state names a checkpoint, but neither its V2 nor its V1 file exists on disk.
/// </exception>
public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null)
{
    // The checkpoint state decides which checkpoint counts as the latest.
    var state = get_checkpoint_state(checkpoint_dir, latest_filename);
    if (state == null || string.IsNullOrEmpty(state.ModelCheckpointPath))
        return null;

    // Accept either on-disk layout; the V2 file is probed first.
    var v2_file = _prefix_to_checkpoint_path(state.ModelCheckpointPath, CheckpointFormatVersion.V2);
    var v1_file = _prefix_to_checkpoint_path(state.ModelCheckpointPath, CheckpointFormatVersion.V1);
    if (!File.Exists(v2_file) && !File.Exists(v1_file))
        throw new ValueError($"Couldn't match files for checkpoint {state.ModelCheckpointPath}");

    return state.ModelCheckpointPath;
}

/// <summary>
/// Loads the <c>CheckpointState</c> stored in the checkpoint file of a directory.
/// </summary>
/// <param name="checkpoint_dir">Directory containing the checkpoint file; also used as the base for relative paths.</param>
/// <param name="latest_filename">Optional name forwarded to <c>_GetCheckpointFilename</c>; <c>null</c> by default.</param>
/// <returns>
/// The parsed state with relative checkpoint paths prefixed by <paramref name="checkpoint_dir"/>,
/// or <c>null</c> when the checkpoint file does not exist.
/// </returns>
/// <exception cref="ValueError">The parsed state has an empty model checkpoint path.</exception>
public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
{
    var state_file = _GetCheckpointFilename(checkpoint_dir, latest_filename);
    if (!File.Exists(state_file))
        return null;

    // NOTE(review): ParseFrom reads the protobuf binary encoding; confirm the
    // checkpoint file is actually written in that encoding.
    var raw = File.ReadAllBytes(state_file);
    var state = CheckpointState.Parser.ParseFrom(raw);
    if (string.IsNullOrEmpty(state.ModelCheckpointPath))
        throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}");

    // Relative paths are resolved against checkpoint_dir; rooted paths are left untouched.
    if (!Path.IsPathRooted(state.ModelCheckpointPath))
        state.ModelCheckpointPath = Path.Combine(checkpoint_dir, state.ModelCheckpointPath);

    for (var idx = 0; idx < state.AllModelCheckpointPaths.Count; idx++)
    {
        var entry = state.AllModelCheckpointPaths[idx];
        if (Path.IsPathRooted(entry))
            continue;
        state.AllModelCheckpointPaths[idx] = Path.Combine(checkpoint_dir, entry);
    }

    return state;
}
}
}

+ 12
- 6
src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs View File

@@ -19,19 +19,28 @@ namespace Tensorflow.Models.ObjectDetection
int sample_1_of_n_eval_on_train_examples = 1)
{
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);

// Create the input functions for TRAIN/EVAL/PREDICT.
Action train_input_fn = () => { };

var eval_input_configs = config.EvalInputReader;

var eval_input_fns = new Action[eval_input_configs.Count];
var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray();
Action eval_on_train_input_fn = () => { };
Action predict_input_fn = () => { };
Action model_fn = () => { };
var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config);

return new TrainAndEvalDict
{
estimator = estimator,
train_steps = train_steps,
train_input_fn = train_input_fn,
eval_input_fns = eval_input_fns,
eval_input_names = eval_input_names
eval_input_names = eval_input_names,
eval_on_train_input_fn = eval_on_train_input_fn,
predict_input_fn = predict_input_fn,
train_steps = train_steps
};
}

@@ -46,10 +55,7 @@ namespace Tensorflow.Models.ObjectDetection
.Select(x => x.ToString())
.ToArray();

var eval_specs = new List<EvalSpec>()
{
new EvalSpec("", null, null) // for test.
};
var eval_specs = new List<EvalSpec>();
foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList()))
{
var exporter_name = index == 0 ? final_exporter_name : $"{final_exporter_name}_{eval_spec_name}";


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs View File

@@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection

// Example settings for the object-detection run.
// NOTE(review): model_dir is a machine-specific absolute path; adjust it for the local environment.
string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model";
string pipeline_config_path = "ObjectDetection/Models/faster_rcnn_resnet101_voc07.config";
int num_train_steps = 50;
int sample_1_of_n_eval_examples = 1;
int sample_1_of_n_eval_on_train_examples = 5;



Loading…
Cancel
Save