Browse Source

_TrainingExecutor

tags/v0.12
Oceania2018 6 years ago
parent
commit
078aa49578
6 changed files with 56 additions and 2 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Estimators/Training.cs
  3. +13
    -1
      src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
  4. +10
    -0
      src/TensorFlowNET.Core/Train/SessionRunArgs.cs
  5. +17
    -1
      src/TensorFlowNET.Core/Train/SessionRunContext.cs
  6. +10
    -0
      src/TensorFlowNET.Core/Train/SessionRunValues.cs

+ 4
- 0
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -11,7 +11,11 @@ namespace Tensorflow.Estimators
public class Estimator : IObjectLife public class Estimator : IObjectLife
{ {
RunConfig _config; RunConfig _config;
public RunConfig config => _config;

ConfigProto _session_config; ConfigProto _session_config;
public ConfigProto session_config => _session_config;

string _model_dir; string _model_dir;


public Estimator(RunConfig config) public Estimator(RunConfig config)


+ 2
- 0
src/TensorFlowNET.Core/Estimators/Training.cs View File

@@ -9,6 +9,8 @@ namespace Tensorflow.Estimators
public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec)
{ {
var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec);
var config = estimator.config;

executor.run(); executor.run();
} }
} }


+ 13
- 1
src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs View File

@@ -4,7 +4,10 @@ using System.Text;


namespace Tensorflow.Estimators namespace Tensorflow.Estimators
{ {
public class _TrainingExecutor
/// <summary>
/// The executor to run `Estimator` training and evaluation.
/// </summary>
internal class _TrainingExecutor
{ {
Estimator _estimator; Estimator _estimator;
EvalSpec _eval_spec; EvalSpec _eval_spec;
@@ -17,11 +20,20 @@ namespace Tensorflow.Estimators


public void run() public void run()
{ {
var config = _estimator.config;
Console.WriteLine("Running training and evaluation locally (non-distributed).");
run_local(); run_local();
} }


/// <summary>
/// Runs training and evaluation locally (non-distributed).
/// </summary>
private void run_local() private void run_local()
{ {
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " +
"after every checkpoint. Checkpoint frequency is determined " +
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " +
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}.");
var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps);
/*_estimator.train(input_fn: _train_spec.input_fn, /*_estimator.train(input_fn: _train_spec.input_fn,
max_steps: _train_spec.max_steps, max_steps: _train_spec.max_steps,


+ 10
- 0
src/TensorFlowNET.Core/Train/SessionRunArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Train
{
public class SessionRunArgs
{
}
}

+ 17
- 1
src/TensorFlowNET.Core/Train/SessionRunContext.cs View File

@@ -6,9 +6,25 @@ namespace Tensorflow.Train
{ {
public class SessionRunContext public class SessionRunContext
{ {
public SessionRunContext(Session session)
SessionRunArgs _original_args;
public SessionRunArgs original_args => _original_args;

Session _session;
public Session session => _session;

bool _stop_requested;
public bool stop_requested => _stop_requested;

public SessionRunContext(SessionRunArgs original_args, Session session)
{ {
_original_args = original_args;
_session = session;
_stop_requested = false;
}


public void request_stop()
{
_stop_requested = true;
} }
} }
} }

+ 10
- 0
src/TensorFlowNET.Core/Train/SessionRunValues.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Train
{
public class SessionRunValues
{
}
}

Loading…
Cancel
Save