Browse Source

_TrainingExecutor

tags/v0.12
Oceania2018 6 years ago
parent
commit
078aa49578
6 changed files with 56 additions and 2 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Estimators/Training.cs
  3. +13
    -1
      src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
  4. +10
    -0
      src/TensorFlowNET.Core/Train/SessionRunArgs.cs
  5. +17
    -1
      src/TensorFlowNET.Core/Train/SessionRunContext.cs
  6. +10
    -0
      src/TensorFlowNET.Core/Train/SessionRunValues.cs

+ 4
- 0
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -11,7 +11,11 @@ namespace Tensorflow.Estimators
public class Estimator : IObjectLife
{
RunConfig _config;
public RunConfig config => _config;

ConfigProto _session_config;
public ConfigProto session_config => _session_config;

string _model_dir;

public Estimator(RunConfig config)


+ 2
- 0
src/TensorFlowNET.Core/Estimators/Training.cs View File

@@ -9,6 +9,8 @@ namespace Tensorflow.Estimators
public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec)
{
var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec);
var config = estimator.config;

executor.run();
}
}


+ 13
- 1
src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs View File

@@ -4,7 +4,10 @@ using System.Text;

namespace Tensorflow.Estimators
{
public class _TrainingExecutor
/// <summary>
/// The executor to run `Estimator` training and evaluation.
/// </summary>
internal class _TrainingExecutor
{
Estimator _estimator;
EvalSpec _eval_spec;
@@ -17,11 +20,20 @@ namespace Tensorflow.Estimators

public void run()
{
var config = _estimator.config;
Console.WriteLine("Running training and evaluation locally (non-distributed).");
run_local();
}

/// <summary>
/// Runs training and evaluation locally (non-distributed).
/// </summary>
private void run_local()
{
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " +
"after every checkpoint. Checkpoint frequency is determined " +
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " +
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}.");
var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps);
/*_estimator.train(input_fn: _train_spec.input_fn,
max_steps: _train_spec.max_steps,


+ 10
- 0
src/TensorFlowNET.Core/Train/SessionRunArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Train
{
public class SessionRunArgs
{
}
}

+ 17
- 1
src/TensorFlowNET.Core/Train/SessionRunContext.cs View File

@@ -6,9 +6,25 @@ namespace Tensorflow.Train
{
public class SessionRunContext
{
public SessionRunContext(Session session)
SessionRunArgs _original_args;
public SessionRunArgs original_args => _original_args;

Session _session;
public Session session => _session;

bool _stop_requested;
public bool stop_requested => _stop_requested;

public SessionRunContext(SessionRunArgs original_args, Session session)
{
_original_args = original_args;
_session = session;
_stop_requested = false;
}

public void request_stop()
{
_stop_requested = true;
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Train/SessionRunValues.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Train
{
public class SessionRunValues
{
}
}

Loading…
Cancel
Save