Browse Source

remove estimator.

tags/v0.13
Oceania2018 5 years ago
parent
commit
cad766bcd5
21 changed files with 7 additions and 754 deletions
  1. +0
    -61
      src/TensorFlowNET.Core/APIs/tf.estimator.cs
  2. +0
    -147
      src/TensorFlowNET.Core/Estimators/Estimator.cs
  3. +0
    -16
      src/TensorFlowNET.Core/Estimators/EstimatorSpec.cs
  4. +0
    -15
      src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs
  5. +0
    -16
      src/TensorFlowNET.Core/Estimators/EvalSpec.cs
  6. +0
    -61
      src/TensorFlowNET.Core/Estimators/Experimental.cs
  7. +0
    -11
      src/TensorFlowNET.Core/Estimators/Exporter.cs
  8. +0
    -14
      src/TensorFlowNET.Core/Estimators/FinalExporter.cs
  9. +0
    -89
      src/TensorFlowNET.Core/Estimators/HyperParams.cs
  10. +0
    -12
      src/TensorFlowNET.Core/Estimators/IEstimatorInputs.cs
  11. +0
    -7
      src/TensorFlowNET.Core/Estimators/README.md
  12. +0
    -103
      src/TensorFlowNET.Core/Estimators/RunConfig.cs
  13. +0
    -22
      src/TensorFlowNET.Core/Estimators/TrainSpec.cs
  14. +0
    -17
      src/TensorFlowNET.Core/Estimators/Training.cs
  15. +0
    -14
      src/TensorFlowNET.Core/Estimators/_Evaluator.cs
  16. +0
    -16
      src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs
  17. +0
    -14
      src/TensorFlowNET.Core/Estimators/_SavedModelExporter.cs
  18. +0
    -48
      src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs
  19. +0
    -5
      src/TensorFlowNET.Core/Operations/Operation.cs
  20. +7
    -2
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  21. +0
    -64
      test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs

+ 0
- 61
src/TensorFlowNET.Core/APIs/tf.estimator.cs View File

@@ -1,61 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using static Tensorflow.Binding;
using Tensorflow.Estimators;
using Tensorflow.Data;

namespace Tensorflow
{
public partial class tensorflow
{
public Estimator_Internal estimator { get; } = new Estimator_Internal();

public class Estimator_Internal
{
public Experimental experimental { get; } = new Experimental();
public Estimator<Thyp> Estimator<Thyp>(Func<IEstimatorInputs, EstimatorSpec> model_fn,
string model_dir = null,
RunConfig config = null,
Thyp hyperParams = default)
=> new Estimator<Thyp>(model_fn: model_fn, model_dir: model_dir, config: config, hyperParams: hyperParams);

public RunConfig RunConfig(string model_dir = null, int save_checkpoints_secs = 180)
=> new RunConfig(model_dir: model_dir, save_checkpoints_secs: save_checkpoints_secs);

public void train_and_evaluate<Thyp>(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec)
=> Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec);

public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps)
=> new TrainSpec(input_fn: input_fn, max_steps: max_steps);

/// <summary>
/// Create an `Exporter` to use with `tf.estimator.EvalSpec`.
/// </summary>
/// <param name="name"></param>
/// <param name="serving_input_receiver_fn"></param>
/// <param name="as_text"></param>
/// <returns></returns>
public FinalExporter FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false)
=> new FinalExporter(name: name, serving_input_receiver_fn: serving_input_receiver_fn,
as_text: as_text);

public EvalSpec EvalSpec(string name, Action input_fn, FinalExporter exporters)
=> new EvalSpec(name: name, input_fn: input_fn, exporters: exporters);
}
}
}

+ 0
- 147
src/TensorFlowNET.Core/Estimators/Estimator.cs View File

@@ -1,147 +0,0 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Text;
using Tensorflow.Data;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Estimators
{
/// <summary>
/// Estimator class to train and evaluate TensorFlow models.
/// </summary>
public class Estimator<Thyp> : IObjectLife
{
RunConfig _config;
public RunConfig config => _config;

ConfigProto _session_config;
public ConfigProto session_config => _session_config;

Func<IEstimatorInputs, EstimatorSpec> _model_fn;

Thyp _hyperParams;

public Estimator(Func<IEstimatorInputs, EstimatorSpec> model_fn,
string model_dir,
RunConfig config,
Thyp hyperParams)
{
_config = config;
_config.model_dir = config.model_dir ?? model_dir;
_session_config = config.session_config;
_model_fn = model_fn;
_hyperParams = hyperParams;
}

public Estimator<Thyp> train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null,
_NewCheckpointListenerForEvaluate<Thyp>[] saving_listeners = null)
{
if(max_steps > 0)
{
var start_step = _load_global_step_from_checkpoint_dir(_config.model_dir);
if (max_steps <= start_step)
{
Console.WriteLine("Skipping training since max_steps has already saved.");
return this;
}
}

var loss = _train_model(input_fn);
print($"Loss for final step: {loss}.");
return this;
}

private int _load_global_step_from_checkpoint_dir(string checkpoint_dir)
{
// var cp = tf.train.latest_checkpoint(checkpoint_dir);
// should use NewCheckpointReader (not implemented)
var cp = tf.train.get_checkpoint_state(checkpoint_dir);

return cp.AllModelCheckpointPaths.Count - 1;
}

private Tensor _train_model(Func<DatasetV1Adapter> input_fn)
{
return _train_model_default(input_fn);
}

private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn)
{
using (var g = tf.Graph().as_default())
{
var global_step_tensor = _create_and_assert_global_step(g);

// Skip creating a read variable if _create_and_assert_global_step
// returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
if (global_step_tensor != null)
TrainingUtil._get_or_create_global_step_read(g);

var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train");
}

throw new NotImplementedException("");
}

private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
var result = _call_input_fn(input_fn, mode);
return EstimatorUtil.parse_input_fn_result(result);
}

/// <summary>
/// Calls the input function.
/// </summary>
/// <param name="input_fn"></param>
/// <param name="mode"></param>
private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode)
{
return input_fn();
}

private Tensor _create_and_assert_global_step(Graph graph)
{
var step = _create_global_step(graph);
Debug.Assert(step == tf.train.get_global_step(graph));
Debug.Assert(step.dtype.is_integer());
return step;
}

private RefVariable _create_global_step(Graph graph)
{
return tf.train.create_global_step(graph);
}

public string eval_dir(string name = null)
{
return Path.Combine(config.model_dir, string.IsNullOrEmpty(name) ? "eval" : $"eval_" + name);
}

public void __init__()
{
throw new NotImplementedException();
}

public void __enter__()
{
throw new NotImplementedException();
}

public void __del__()
{
throw new NotImplementedException();
}

public void __exit__()
{
throw new NotImplementedException();
}

public void Dispose()
{
}
}
}

+ 0
- 16
src/TensorFlowNET.Core/Estimators/EstimatorSpec.cs View File

@@ -1,16 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Estimators
{
public class EstimatorSpec
{
public EstimatorSpec(Operation train_op)
{

}
}
}

+ 0
- 15
src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs View File

@@ -1,15 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;

namespace Tensorflow.Estimators
{
public class EstimatorUtil
{
public static (Dictionary<string, Tensor>, Dictionary<string, Tensor>) parse_input_fn_result(DatasetV1Adapter result)
{
throw new NotImplementedException("");
}
}
}

+ 0
- 16
src/TensorFlowNET.Core/Estimators/EvalSpec.cs View File

@@ -1,16 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public class EvalSpec
{
string _name;

public EvalSpec(string name, Action input_fn, FinalExporter exporters)
{
_name = name;
}
}
}

+ 0
- 61
src/TensorFlowNET.Core/Estimators/Experimental.cs View File

@@ -1,61 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Estimators
{
public class Experimental
{
/// <summary>
/// Creates hook to stop if metric does not increase within given max steps.
/// </summary>
/// <typeparam name="Thyp">type of hyper parameters</typeparam>
/// <param name="estimator"></param>
/// <param name="metric_name"></param>
/// <param name="max_steps_without_increase"></param>
/// <param name="eval_dir"></param>
/// <param name="min_steps"></param>
/// <param name="run_every_secs"></param>
/// <param name="run_every_steps"></param>
/// <returns></returns>
public object stop_if_no_increase_hook<Thyp>(Estimator<Thyp> estimator,
string metric_name,
int max_steps_without_increase,
string eval_dir = null,
int min_steps = 0,
int run_every_secs = 60,
int run_every_steps = 0)
=> _stop_if_no_metric_improvement_hook(estimator: estimator,
metric_name: metric_name,
max_steps_without_increase: max_steps_without_increase,
eval_dir: eval_dir,
min_steps: min_steps,
run_every_secs: run_every_secs,
run_every_steps: run_every_steps);

private object _stop_if_no_metric_improvement_hook<Thyp>(Estimator<Thyp> estimator,
string metric_name,
int max_steps_without_increase,
string eval_dir = null,
int min_steps = 0,
int run_every_secs = 60,
int run_every_steps = 0)
{
eval_dir = eval_dir ?? estimator.eval_dir();
// var is_lhs_better = higher_is_better ? operator.gt: operator.lt;
Func<bool> stop_if_no_metric_improvement_fn = () =>
{
return false;
};

return make_early_stopping_hook();
}

public object make_early_stopping_hook()
{
throw new NotImplementedException("");
}
}
}

+ 0
- 11
src/TensorFlowNET.Core/Estimators/Exporter.cs View File

@@ -1,11 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public abstract class Exporter<Thyp>
{
public abstract void export(Estimator<Thyp> estimator, string export_path, string checkpoint_path);
}
}

+ 0
- 14
src/TensorFlowNET.Core/Estimators/FinalExporter.cs View File

@@ -1,14 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public class FinalExporter
{
public FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false)
{

}
}
}

+ 0
- 89
src/TensorFlowNET.Core/Estimators/HyperParams.cs View File

@@ -1,89 +0,0 @@
using System.IO;

namespace Tensorflow.Estimators
{
public class HyperParams
{
/// <summary>
/// root dir
/// </summary>
public string data_root_dir { get; set; }

/// <summary>
/// results dir
/// </summary>
public string result_dir { get; set; } = "results";

/// <summary>
/// model dir
/// </summary>
public string model_dir { get; set; } = "model";

public string eval_dir { get; set; } = "eval";

public string test_dir { get; set; } = "test";

public int dim { get; set; } = 300;
public float dropout { get; set; } = 0.5f;
public int num_oov_buckets { get; set; } = 1;
public int epochs { get; set; } = 25;
public int epoch_no_imprv { get; set; } = 3;
public int batch_size { get; set; } = 20;
public int buffer { get; set; } = 15000;
public int lstm_size { get; set; } = 100;
public string lr_method { get; set; } = "adam";
public float lr { get; set; } = 0.001f;
public float lr_decay { get; set; } = 0.9f;

/// <summary>
/// lstm on chars
/// </summary>
public int hidden_size_char { get; set; } = 100;

/// <summary>
/// lstm on word embeddings
/// </summary>
public int hidden_size_lstm { get; set; } = 300;

/// <summary>
/// is clipping
/// </summary>
public bool clip { get; set; } = false;

public string filepath_dev { get; set; }
public string filepath_test { get; set; }
public string filepath_train { get; set; }

public string filepath_words { get; set; }
public string filepath_chars { get; set; }
public string filepath_tags { get; set; }
public string filepath_glove { get; set; }

public HyperParams(string dataDir)
{
data_root_dir = dataDir;

if (string.IsNullOrEmpty(data_root_dir))
throw new ValueError("Please specifiy the root data directory");

if (!Directory.Exists(data_root_dir))
Directory.CreateDirectory(data_root_dir);

result_dir = Path.Combine(data_root_dir, result_dir);
if (!Directory.Exists(result_dir))
Directory.CreateDirectory(result_dir);

model_dir = Path.Combine(result_dir, model_dir);
if (!Directory.Exists(model_dir))
Directory.CreateDirectory(model_dir);

test_dir = Path.Combine(result_dir, test_dir);
if (!Directory.Exists(test_dir))
Directory.CreateDirectory(test_dir);

eval_dir = Path.Combine(result_dir, eval_dir);
if (!Directory.Exists(eval_dir))
Directory.CreateDirectory(eval_dir);
}
}
}

+ 0
- 12
src/TensorFlowNET.Core/Estimators/IEstimatorInputs.cs View File

@@ -1,12 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Estimators
{
public interface IEstimatorInputs
{
}
}

+ 0
- 7
src/TensorFlowNET.Core/Estimators/README.md View File

@@ -1,7 +0,0 @@
# TensorFlow Estimator

TensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model.



https://github.com/tensorflow/estimator

+ 0
- 103
src/TensorFlowNET.Core/Estimators/RunConfig.cs View File

@@ -1,103 +0,0 @@
using System;

namespace Tensorflow.Estimators
{
public class RunConfig
{
// A list of the property names in RunConfig that the user is allowed to change.
private static readonly string[] _DEFAULT_REPLACEABLE_LIST = new []
{
"model_dir",
"tf_random_seed",
"save_summary_steps",
"save_checkpoints_steps",
"save_checkpoints_secs",
"session_config",
"keep_checkpoint_max",
"keep_checkpoint_every_n_hours",
"log_step_count_steps",
"train_distribute",
"device_fn",
"protocol",
"eval_distribute",
"experimental_distribute",
"experimental_max_worker_delay_secs",
"session_creation_timeout_secs"
};


#region const values

private const string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.";
private const string _TF_CONFIG_ENV = "TF_CONFIG";
private const string _TASK_ENV_KEY = "task";
private const string _TASK_TYPE_KEY = "type";
private const string _TASK_ID_KEY = "index";
private const string _CLUSTER_KEY = "cluster";
private const string _SERVICE_KEY = "service";
private const string _SESSION_MASTER_KEY = "session_master";
private const string _EVAL_SESSION_MASTER_KEY = "eval_session_master";
private const string _MODEL_DIR_KEY = "model_dir";
private const string _LOCAL_MASTER = "";
private const string _GRPC_SCHEME = "grpc://";

#endregion

public string model_dir { get; set; }
public ConfigProto session_config { get; set; }
public int? tf_random_seed { get; set; }
public int save_summary_steps { get; set; } = 100;
public int save_checkpoints_steps { get; set; }
public int save_checkpoints_secs { get; set; } = 600;
public int keep_checkpoint_max { get; set; } = 5;
public int keep_checkpoint_every_n_hours { get; set; } = 10000;
public int log_step_count_steps{ get; set; } = 100;
public object train_distribute { get; set; }
public object device_fn { get; set; }
public object protocol { get; set; }
public object eval_distribute { get; set; }
public object experimental_distribute { get; set; }
public object experimental_max_worker_delay_secs { get; set; }
public int session_creation_timeout_secs { get; set; } = 7200;
public object service { get; set; }

public RunConfig()
{
Initialize();
}

public RunConfig(string model_dir,
int save_checkpoints_secs)
{
this.model_dir = model_dir;
this.save_checkpoints_secs = save_checkpoints_secs;
Initialize();
}

public RunConfig(
string model_dir = null,
int? tf_random_seed = null,
int save_summary_steps=100,
object save_checkpoints_steps = null, // _USE_DEFAULT
object save_checkpoints_secs = null, // _USE_DEFAULT
object session_config = null,
int keep_checkpoint_max = 5,
int keep_checkpoint_every_n_hours = 10000,
int log_step_count_steps = 100,
object train_distribute = null,
object device_fn = null,
object protocol = null,
object eval_distribute = null,
object experimental_distribute = null,
object experimental_max_worker_delay_secs = null,
int session_creation_timeout_secs = 7200)
{
this.model_dir = model_dir;
Initialize();
}

private void Initialize()
{
}
}
}

+ 0
- 22
src/TensorFlowNET.Core/Estimators/TrainSpec.cs View File

@@ -1,22 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;

namespace Tensorflow.Estimators
{
public class TrainSpec
{
int _max_steps;
public int max_steps => _max_steps;

Func<DatasetV1Adapter> _input_fn;
public Func<DatasetV1Adapter> input_fn => _input_fn;

public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps)
{
_max_steps = max_steps;
_input_fn = input_fn;
}
}
}

+ 0
- 17
src/TensorFlowNET.Core/Estimators/Training.cs View File

@@ -1,17 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public class Training
{
public static void train_and_evaluate<Thyp>(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec)
{
var executor = new _TrainingExecutor<Thyp>(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec);
var config = estimator.config;

executor.run();
}
}
}

+ 0
- 14
src/TensorFlowNET.Core/Estimators/_Evaluator.cs View File

@@ -1,14 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public class _Evaluator<Thyp>
{
public _Evaluator(Estimator<Thyp> estimator, EvalSpec eval_spec, int max_training_steps)
{

}
}
}

+ 0
- 16
src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs View File

@@ -1,16 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public class _NewCheckpointListenerForEvaluate<Thyp>
{
_Evaluator<Thyp> _evaluator;

public _NewCheckpointListenerForEvaluate(_Evaluator<Thyp> evaluator, int eval_throttle_secs)
{
_evaluator = evaluator;
}
}
}

+ 0
- 14
src/TensorFlowNET.Core/Estimators/_SavedModelExporter.cs View File

@@ -1,14 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
public class _SavedModelExporter<Thyp> : Exporter<Thyp>
{
public override void export(Estimator<Thyp> estimator, string export_path, string checkpoint_path)
{
}
}
}

+ 0
- 48
src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs View File

@@ -1,48 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Estimators
{
/// <summary>
/// The executor to run `Estimator` training and evaluation.
/// </summary>
internal class _TrainingExecutor<Thyp>
{
Estimator<Thyp> _estimator;
EvalSpec _eval_spec;
TrainSpec _train_spec;

public _TrainingExecutor(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec)
{
_estimator = estimator;
_train_spec = train_spec;
_eval_spec = eval_spec;
}

public void run()
{
var config = _estimator.config;
Console.WriteLine("Running training and evaluation locally (non-distributed).");
run_local();
}

/// <summary>
/// Runs training and evaluation locally (non-distributed).
/// </summary>
private void run_local()
{
var train_hooks = new Action[0];
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " +
"after every checkpoint. Checkpoint frequency is determined " +
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " +
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}.");
var evaluator = new _Evaluator<Thyp>(_estimator, _eval_spec, _train_spec.max_steps);
var saving_listeners = new _NewCheckpointListenerForEvaluate<Thyp>[0];
_estimator.train(input_fn: _train_spec.input_fn,
max_steps: _train_spec.max_steps,
hooks: train_hooks,
saving_listeners: saving_listeners);
}
}
}

+ 0
- 5
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -186,11 +186,6 @@ namespace Tensorflow
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);
if (node_def.Name.Equals("learn_rate/cond/pred_id"))
{
}
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_is_stateful = op_def.IsStateful;


+ 7
- 2
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -42,8 +42,14 @@ https://tensorflownet.readthedocs.io</Description>
</PropertyGroup>

<ItemGroup>
<Compile Remove="Distribute\**" />
<Compile Remove="Models\**" />
<Compile Remove="runtimes\**" />
<EmbeddedResource Remove="Distribute\**" />
<EmbeddedResource Remove="Models\**" />
<EmbeddedResource Remove="runtimes\**" />
<None Remove="Distribute\**" />
<None Remove="Models\**" />
<None Remove="runtimes\**" />
<None Include="..\..\LICENSE">
<Pack>True</Pack>
@@ -61,8 +67,7 @@ https://tensorflownet.readthedocs.io</Description>
</ItemGroup>

<ItemGroup>
<Folder Include="Distribute\" />
<Folder Include="Estimators\" />
<Folder Include="Keras\Initializers\" />
<Folder Include="Models\" />
</ItemGroup>
</Project>

+ 0
- 64
test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs View File

@@ -1,64 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using Tensorflow.Estimators;

namespace TensorFlowNET.UnitTest.Estimators
{
/// <summary>
/// estimator/tensorflow_estimator/python/estimator/run_config_test.py
/// </summary>
[TestClass]
public class RunConfigTest
{
private static readonly string _TEST_DIR = "test_dir";
private static readonly string _MASTER = "master_";
private static readonly string _NOT_SUPPORTED_REPLACE_PROPERTY_MSG = "Replacing .*is not supported";
private static readonly string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.";
private static readonly string _MODEL_DIR_ERR = "model_dir should be non-empty";
private static readonly string _MODEL_DIR_TF_CONFIG_ERR = "model_dir in TF_CONFIG should be non-empty";
private static readonly string _MODEL_DIR_MISMATCH_ERR = "`model_dir` provided in RunConfig construct, if set, must have the same value as the model_dir in TF_CONFIG. ";
private static readonly string _SAVE_SUMMARY_STEPS_ERR = "save_summary_steps should be >= 0";
private static readonly string _SAVE_CKPT_STEPS_ERR = "save_checkpoints_steps should be >= 0";
private static readonly string _SAVE_CKPT_SECS_ERR = "save_checkpoints_secs should be >= 0";
private static readonly string _SESSION_CONFIG_ERR = "session_config must be instance of ConfigProto";
private static readonly string _KEEP_CKPT_MAX_ERR = "keep_checkpoint_max should be >= 0";
private static readonly string _KEEP_CKPT_HOURS_ERR = "keep_checkpoint_every_n_hours should be > 0";
private static readonly string _TF_RANDOM_SEED_ERR = "tf_random_seed must be integer";
private static readonly string _DEVICE_FN_ERR = "device_fn must be callable with exactly one argument \"op\".";
private static readonly string _ONE_CHIEF_ERR = "The \"cluster\" in TF_CONFIG must have only one \"chief\" node.";
private static readonly string _ONE_MASTER_ERR = "The \"cluster\" in TF_CONFIG must have only one \"master\" node.";
private static readonly string _MISSING_CHIEF_ERR = "If \"cluster\" is set .* it must have one \"chief\" node";
private static readonly string _MISSING_TASK_TYPE_ERR = "If \"cluster\" is set .* task type must be set";
private static readonly string _MISSING_TASK_ID_ERR = "If \"cluster\" is set .* task index must be set";
private static readonly string _INVALID_TASK_INDEX_ERR = "is not a valid task_id";
private static readonly string _NEGATIVE_TASK_INDEX_ERR = "Task index must be non-negative number.";
private static readonly string _INVALID_TASK_TYPE_ERR = "is not a valid task_type";
private static readonly string _INVALID_TASK_TYPE_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task type must be WORKER.";
private static readonly string _INVALID_TASK_INDEX_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task index must be 0.";
private static readonly string _INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, task_type `evaluator` is not supported.";
private static readonly string _INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, job `chief` is not supported.";
private static readonly string _INVALID_SERVICE_TYPE_ERR = "If \"service\" is set in TF_CONFIG, it must be a dict. Given";
private static readonly string _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR = "experimental_max_worker_delay_secs must be an integer if set.";
private static readonly string _SESSION_CREATION_TIMEOUT_SECS_ERR = "session_creation_timeout_secs should be > 0";

[TestMethod]
public void test_default_property_values()
{
var config = new RunConfig();

Assert.IsNull(config.model_dir);
Assert.IsNull(config.session_config);
Assert.IsNull(config.tf_random_seed);
Assert.AreEqual(100, config.save_summary_steps);
Assert.AreEqual(600, config.save_checkpoints_secs);
Assert.AreEqual(5, config.keep_checkpoint_max);
Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours);
Assert.IsNull(config.service);
Assert.IsNull(config.device_fn);
Assert.IsNull(config.experimental_max_worker_delay_secs);
Assert.AreEqual(7200, config.session_creation_timeout_secs);
}
}
}

Loading…
Cancel
Save