@@ -11,7 +11,11 @@ namespace Tensorflow.Estimators | |||||
public class Estimator : IObjectLife | public class Estimator : IObjectLife | ||||
{ | { | ||||
RunConfig _config; | RunConfig _config; | ||||
public RunConfig config => _config; | |||||
ConfigProto _session_config; | ConfigProto _session_config; | ||||
public ConfigProto session_config => _session_config; | |||||
string _model_dir; | string _model_dir; | ||||
public Estimator(RunConfig config) | public Estimator(RunConfig config) | ||||
@@ -9,6 +9,8 @@ namespace Tensorflow.Estimators | |||||
public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | ||||
{ | { | ||||
var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | ||||
var config = estimator.config; | |||||
executor.run(); | executor.run(); | ||||
} | } | ||||
} | } | ||||
@@ -4,7 +4,10 @@ using System.Text; | |||||
namespace Tensorflow.Estimators | namespace Tensorflow.Estimators | ||||
{ | { | ||||
public class _TrainingExecutor | |||||
/// <summary> | |||||
/// The executor to run `Estimator` training and evaluation. | |||||
/// </summary> | |||||
internal class _TrainingExecutor | |||||
{ | { | ||||
Estimator _estimator; | Estimator _estimator; | ||||
EvalSpec _eval_spec; | EvalSpec _eval_spec; | ||||
@@ -17,11 +20,20 @@ namespace Tensorflow.Estimators | |||||
public void run() | public void run() | ||||
{ | { | ||||
var config = _estimator.config; | |||||
Console.WriteLine("Running training and evaluation locally (non-distributed)."); | |||||
run_local(); | run_local(); | ||||
} | } | ||||
/// <summary> | |||||
/// Runs training and evaluation locally (non-distributed). | |||||
/// </summary> | |||||
private void run_local() | private void run_local() | ||||
{ | { | ||||
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + | |||||
"after every checkpoint. Checkpoint frequency is determined " + | |||||
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + | |||||
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); | |||||
var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); | var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); | ||||
/*_estimator.train(input_fn: _train_spec.input_fn, | /*_estimator.train(input_fn: _train_spec.input_fn, | ||||
max_steps: _train_spec.max_steps, | max_steps: _train_spec.max_steps, | ||||
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Train | |||||
{ | |||||
public class SessionRunArgs | |||||
{ | |||||
} | |||||
} |
@@ -6,9 +6,25 @@ namespace Tensorflow.Train | |||||
{ | { | ||||
public class SessionRunContext | public class SessionRunContext | ||||
{ | { | ||||
public SessionRunContext(Session session) | |||||
SessionRunArgs _original_args; | |||||
public SessionRunArgs original_args => _original_args; | |||||
Session _session; | |||||
public Session session => _session; | |||||
bool _stop_requested; | |||||
public bool stop_requested => _stop_requested; | |||||
public SessionRunContext(SessionRunArgs original_args, Session session) | |||||
{ | { | ||||
_original_args = original_args; | |||||
_session = session; | |||||
_stop_requested = false; | |||||
} | |||||
public void request_stop() | |||||
{ | |||||
_stop_requested = true; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Train | |||||
{ | |||||
public class SessionRunValues | |||||
{ | |||||
} | |||||
} |