@@ -1,61 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Estimators; | |||
using Tensorflow.Data; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public Estimator_Internal estimator { get; } = new Estimator_Internal(); | |||
public class Estimator_Internal | |||
{ | |||
public Experimental experimental { get; } = new Experimental(); | |||
public Estimator<Thyp> Estimator<Thyp>(Func<IEstimatorInputs, EstimatorSpec> model_fn, | |||
string model_dir = null, | |||
RunConfig config = null, | |||
Thyp hyperParams = default) | |||
=> new Estimator<Thyp>(model_fn: model_fn, model_dir: model_dir, config: config, hyperParams: hyperParams); | |||
public RunConfig RunConfig(string model_dir = null, int save_checkpoints_secs = 180) | |||
=> new RunConfig(model_dir: model_dir, save_checkpoints_secs: save_checkpoints_secs); | |||
public void train_and_evaluate<Thyp>(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||
=> Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||
public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||
=> new TrainSpec(input_fn: input_fn, max_steps: max_steps); | |||
/// <summary> | |||
/// Create an `Exporter` to use with `tf.estimator.EvalSpec`. | |||
/// </summary> | |||
/// <param name="name"></param> | |||
/// <param name="serving_input_receiver_fn"></param> | |||
/// <param name="as_text"></param> | |||
/// <returns></returns> | |||
public FinalExporter FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) | |||
=> new FinalExporter(name: name, serving_input_receiver_fn: serving_input_receiver_fn, | |||
as_text: as_text); | |||
public EvalSpec EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||
=> new EvalSpec(name: name, input_fn: input_fn, exporters: exporters); | |||
} | |||
} | |||
} |
@@ -1,147 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Text; | |||
using Tensorflow.Data; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Estimators | |||
{ | |||
/// <summary> | |||
/// Estimator class to train and evaluate TensorFlow models. | |||
/// </summary> | |||
public class Estimator<Thyp> : IObjectLife | |||
{ | |||
RunConfig _config; | |||
public RunConfig config => _config; | |||
ConfigProto _session_config; | |||
public ConfigProto session_config => _session_config; | |||
Func<IEstimatorInputs, EstimatorSpec> _model_fn; | |||
Thyp _hyperParams; | |||
public Estimator(Func<IEstimatorInputs, EstimatorSpec> model_fn, | |||
string model_dir, | |||
RunConfig config, | |||
Thyp hyperParams) | |||
{ | |||
_config = config; | |||
_config.model_dir = config.model_dir ?? model_dir; | |||
_session_config = config.session_config; | |||
_model_fn = model_fn; | |||
_hyperParams = hyperParams; | |||
} | |||
public Estimator<Thyp> train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null, | |||
_NewCheckpointListenerForEvaluate<Thyp>[] saving_listeners = null) | |||
{ | |||
if(max_steps > 0) | |||
{ | |||
var start_step = _load_global_step_from_checkpoint_dir(_config.model_dir); | |||
if (max_steps <= start_step) | |||
{ | |||
Console.WriteLine("Skipping training since max_steps has already saved."); | |||
return this; | |||
} | |||
} | |||
var loss = _train_model(input_fn); | |||
print($"Loss for final step: {loss}."); | |||
return this; | |||
} | |||
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||
{ | |||
// var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
// should use NewCheckpointReader (not implemented) | |||
var cp = tf.train.get_checkpoint_state(checkpoint_dir); | |||
return cp.AllModelCheckpointPaths.Count - 1; | |||
} | |||
private Tensor _train_model(Func<DatasetV1Adapter> input_fn) | |||
{ | |||
return _train_model_default(input_fn); | |||
} | |||
private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn) | |||
{ | |||
using (var g = tf.Graph().as_default()) | |||
{ | |||
var global_step_tensor = _create_and_assert_global_step(g); | |||
// Skip creating a read variable if _create_and_assert_global_step | |||
// returns None (e.g. tf.contrib.estimator.SavedModelEstimator). | |||
if (global_step_tensor != null) | |||
TrainingUtil._get_or_create_global_step_read(g); | |||
var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); | |||
} | |||
throw new NotImplementedException(""); | |||
} | |||
private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||
{ | |||
var result = _call_input_fn(input_fn, mode); | |||
return EstimatorUtil.parse_input_fn_result(result); | |||
} | |||
/// <summary> | |||
/// Calls the input function. | |||
/// </summary> | |||
/// <param name="input_fn"></param> | |||
/// <param name="mode"></param> | |||
private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||
{ | |||
return input_fn(); | |||
} | |||
private Tensor _create_and_assert_global_step(Graph graph) | |||
{ | |||
var step = _create_global_step(graph); | |||
Debug.Assert(step == tf.train.get_global_step(graph)); | |||
Debug.Assert(step.dtype.is_integer()); | |||
return step; | |||
} | |||
private RefVariable _create_global_step(Graph graph) | |||
{ | |||
return tf.train.create_global_step(graph); | |||
} | |||
public string eval_dir(string name = null) | |||
{ | |||
return Path.Combine(config.model_dir, string.IsNullOrEmpty(name) ? "eval" : $"eval_" + name); | |||
} | |||
public void __init__() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public void __enter__() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public void __del__() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public void __exit__() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -1,16 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class EstimatorSpec | |||
{ | |||
public EstimatorSpec(Operation train_op) | |||
{ | |||
} | |||
} | |||
} |
@@ -1,15 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Data; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class EstimatorUtil | |||
{ | |||
public static (Dictionary<string, Tensor>, Dictionary<string, Tensor>) parse_input_fn_result(DatasetV1Adapter result) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -1,16 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class EvalSpec | |||
{ | |||
string _name; | |||
public EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||
{ | |||
_name = name; | |||
} | |||
} | |||
} |
@@ -1,61 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class Experimental | |||
{ | |||
/// <summary> | |||
/// Creates hook to stop if metric does not increase within given max steps. | |||
/// </summary> | |||
/// <typeparam name="Thyp">type of hyper parameters</typeparam> | |||
/// <param name="estimator"></param> | |||
/// <param name="metric_name"></param> | |||
/// <param name="max_steps_without_increase"></param> | |||
/// <param name="eval_dir"></param> | |||
/// <param name="min_steps"></param> | |||
/// <param name="run_every_secs"></param> | |||
/// <param name="run_every_steps"></param> | |||
/// <returns></returns> | |||
public object stop_if_no_increase_hook<Thyp>(Estimator<Thyp> estimator, | |||
string metric_name, | |||
int max_steps_without_increase, | |||
string eval_dir = null, | |||
int min_steps = 0, | |||
int run_every_secs = 60, | |||
int run_every_steps = 0) | |||
=> _stop_if_no_metric_improvement_hook(estimator: estimator, | |||
metric_name: metric_name, | |||
max_steps_without_increase: max_steps_without_increase, | |||
eval_dir: eval_dir, | |||
min_steps: min_steps, | |||
run_every_secs: run_every_secs, | |||
run_every_steps: run_every_steps); | |||
private object _stop_if_no_metric_improvement_hook<Thyp>(Estimator<Thyp> estimator, | |||
string metric_name, | |||
int max_steps_without_increase, | |||
string eval_dir = null, | |||
int min_steps = 0, | |||
int run_every_secs = 60, | |||
int run_every_steps = 0) | |||
{ | |||
eval_dir = eval_dir ?? estimator.eval_dir(); | |||
// var is_lhs_better = higher_is_better ? operator.gt: operator.lt; | |||
Func<bool> stop_if_no_metric_improvement_fn = () => | |||
{ | |||
return false; | |||
}; | |||
return make_early_stopping_hook(); | |||
} | |||
public object make_early_stopping_hook() | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -1,11 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public abstract class Exporter<Thyp> | |||
{ | |||
public abstract void export(Estimator<Thyp> estimator, string export_path, string checkpoint_path); | |||
} | |||
} |
@@ -1,14 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class FinalExporter | |||
{ | |||
public FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) | |||
{ | |||
} | |||
} | |||
} |
@@ -1,89 +0,0 @@ | |||
using System.IO; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class HyperParams | |||
{ | |||
/// <summary> | |||
/// root dir | |||
/// </summary> | |||
public string data_root_dir { get; set; } | |||
/// <summary> | |||
/// results dir | |||
/// </summary> | |||
public string result_dir { get; set; } = "results"; | |||
/// <summary> | |||
/// model dir | |||
/// </summary> | |||
public string model_dir { get; set; } = "model"; | |||
public string eval_dir { get; set; } = "eval"; | |||
public string test_dir { get; set; } = "test"; | |||
public int dim { get; set; } = 300; | |||
public float dropout { get; set; } = 0.5f; | |||
public int num_oov_buckets { get; set; } = 1; | |||
public int epochs { get; set; } = 25; | |||
public int epoch_no_imprv { get; set; } = 3; | |||
public int batch_size { get; set; } = 20; | |||
public int buffer { get; set; } = 15000; | |||
public int lstm_size { get; set; } = 100; | |||
public string lr_method { get; set; } = "adam"; | |||
public float lr { get; set; } = 0.001f; | |||
public float lr_decay { get; set; } = 0.9f; | |||
/// <summary> | |||
/// lstm on chars | |||
/// </summary> | |||
public int hidden_size_char { get; set; } = 100; | |||
/// <summary> | |||
/// lstm on word embeddings | |||
/// </summary> | |||
public int hidden_size_lstm { get; set; } = 300; | |||
/// <summary> | |||
/// is clipping | |||
/// </summary> | |||
public bool clip { get; set; } = false; | |||
public string filepath_dev { get; set; } | |||
public string filepath_test { get; set; } | |||
public string filepath_train { get; set; } | |||
public string filepath_words { get; set; } | |||
public string filepath_chars { get; set; } | |||
public string filepath_tags { get; set; } | |||
public string filepath_glove { get; set; } | |||
public HyperParams(string dataDir) | |||
{ | |||
data_root_dir = dataDir; | |||
if (string.IsNullOrEmpty(data_root_dir)) | |||
throw new ValueError("Please specifiy the root data directory"); | |||
if (!Directory.Exists(data_root_dir)) | |||
Directory.CreateDirectory(data_root_dir); | |||
result_dir = Path.Combine(data_root_dir, result_dir); | |||
if (!Directory.Exists(result_dir)) | |||
Directory.CreateDirectory(result_dir); | |||
model_dir = Path.Combine(result_dir, model_dir); | |||
if (!Directory.Exists(model_dir)) | |||
Directory.CreateDirectory(model_dir); | |||
test_dir = Path.Combine(result_dir, test_dir); | |||
if (!Directory.Exists(test_dir)) | |||
Directory.CreateDirectory(test_dir); | |||
eval_dir = Path.Combine(result_dir, eval_dir); | |||
if (!Directory.Exists(eval_dir)) | |||
Directory.CreateDirectory(eval_dir); | |||
} | |||
} | |||
} |
@@ -1,12 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public interface IEstimatorInputs | |||
{ | |||
} | |||
} |
@@ -1,7 +0,0 @@ | |||
# TensorFlow Estimator | |||
TensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model. | |||
https://github.com/tensorflow/estimator |
@@ -1,103 +0,0 @@ | |||
using System; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class RunConfig | |||
{ | |||
// A list of the property names in RunConfig that the user is allowed to change. | |||
private static readonly string[] _DEFAULT_REPLACEABLE_LIST = new [] | |||
{ | |||
"model_dir", | |||
"tf_random_seed", | |||
"save_summary_steps", | |||
"save_checkpoints_steps", | |||
"save_checkpoints_secs", | |||
"session_config", | |||
"keep_checkpoint_max", | |||
"keep_checkpoint_every_n_hours", | |||
"log_step_count_steps", | |||
"train_distribute", | |||
"device_fn", | |||
"protocol", | |||
"eval_distribute", | |||
"experimental_distribute", | |||
"experimental_max_worker_delay_secs", | |||
"session_creation_timeout_secs" | |||
}; | |||
#region const values | |||
private const string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; | |||
private const string _TF_CONFIG_ENV = "TF_CONFIG"; | |||
private const string _TASK_ENV_KEY = "task"; | |||
private const string _TASK_TYPE_KEY = "type"; | |||
private const string _TASK_ID_KEY = "index"; | |||
private const string _CLUSTER_KEY = "cluster"; | |||
private const string _SERVICE_KEY = "service"; | |||
private const string _SESSION_MASTER_KEY = "session_master"; | |||
private const string _EVAL_SESSION_MASTER_KEY = "eval_session_master"; | |||
private const string _MODEL_DIR_KEY = "model_dir"; | |||
private const string _LOCAL_MASTER = ""; | |||
private const string _GRPC_SCHEME = "grpc://"; | |||
#endregion | |||
public string model_dir { get; set; } | |||
public ConfigProto session_config { get; set; } | |||
public int? tf_random_seed { get; set; } | |||
public int save_summary_steps { get; set; } = 100; | |||
public int save_checkpoints_steps { get; set; } | |||
public int save_checkpoints_secs { get; set; } = 600; | |||
public int keep_checkpoint_max { get; set; } = 5; | |||
public int keep_checkpoint_every_n_hours { get; set; } = 10000; | |||
public int log_step_count_steps{ get; set; } = 100; | |||
public object train_distribute { get; set; } | |||
public object device_fn { get; set; } | |||
public object protocol { get; set; } | |||
public object eval_distribute { get; set; } | |||
public object experimental_distribute { get; set; } | |||
public object experimental_max_worker_delay_secs { get; set; } | |||
public int session_creation_timeout_secs { get; set; } = 7200; | |||
public object service { get; set; } | |||
public RunConfig() | |||
{ | |||
Initialize(); | |||
} | |||
public RunConfig(string model_dir, | |||
int save_checkpoints_secs) | |||
{ | |||
this.model_dir = model_dir; | |||
this.save_checkpoints_secs = save_checkpoints_secs; | |||
Initialize(); | |||
} | |||
public RunConfig( | |||
string model_dir = null, | |||
int? tf_random_seed = null, | |||
int save_summary_steps=100, | |||
object save_checkpoints_steps = null, // _USE_DEFAULT | |||
object save_checkpoints_secs = null, // _USE_DEFAULT | |||
object session_config = null, | |||
int keep_checkpoint_max = 5, | |||
int keep_checkpoint_every_n_hours = 10000, | |||
int log_step_count_steps = 100, | |||
object train_distribute = null, | |||
object device_fn = null, | |||
object protocol = null, | |||
object eval_distribute = null, | |||
object experimental_distribute = null, | |||
object experimental_max_worker_delay_secs = null, | |||
int session_creation_timeout_secs = 7200) | |||
{ | |||
this.model_dir = model_dir; | |||
Initialize(); | |||
} | |||
private void Initialize() | |||
{ | |||
} | |||
} | |||
} |
@@ -1,22 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Data; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class TrainSpec | |||
{ | |||
int _max_steps; | |||
public int max_steps => _max_steps; | |||
Func<DatasetV1Adapter> _input_fn; | |||
public Func<DatasetV1Adapter> input_fn => _input_fn; | |||
public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||
{ | |||
_max_steps = max_steps; | |||
_input_fn = input_fn; | |||
} | |||
} | |||
} |
@@ -1,17 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class Training | |||
{ | |||
public static void train_and_evaluate<Thyp>(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||
{ | |||
var executor = new _TrainingExecutor<Thyp>(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||
var config = estimator.config; | |||
executor.run(); | |||
} | |||
} | |||
} |
@@ -1,14 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class _Evaluator<Thyp> | |||
{ | |||
public _Evaluator(Estimator<Thyp> estimator, EvalSpec eval_spec, int max_training_steps) | |||
{ | |||
} | |||
} | |||
} |
@@ -1,16 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class _NewCheckpointListenerForEvaluate<Thyp> | |||
{ | |||
_Evaluator<Thyp> _evaluator; | |||
public _NewCheckpointListenerForEvaluate(_Evaluator<Thyp> evaluator, int eval_throttle_secs) | |||
{ | |||
_evaluator = evaluator; | |||
} | |||
} | |||
} |
@@ -1,14 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
public class _SavedModelExporter<Thyp> : Exporter<Thyp> | |||
{ | |||
public override void export(Estimator<Thyp> estimator, string export_path, string checkpoint_path) | |||
{ | |||
} | |||
} | |||
} |
@@ -1,48 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Estimators | |||
{ | |||
/// <summary> | |||
/// The executor to run `Estimator` training and evaluation. | |||
/// </summary> | |||
internal class _TrainingExecutor<Thyp> | |||
{ | |||
Estimator<Thyp> _estimator; | |||
EvalSpec _eval_spec; | |||
TrainSpec _train_spec; | |||
public _TrainingExecutor(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||
{ | |||
_estimator = estimator; | |||
_train_spec = train_spec; | |||
_eval_spec = eval_spec; | |||
} | |||
public void run() | |||
{ | |||
var config = _estimator.config; | |||
Console.WriteLine("Running training and evaluation locally (non-distributed)."); | |||
run_local(); | |||
} | |||
/// <summary> | |||
/// Runs training and evaluation locally (non-distributed). | |||
/// </summary> | |||
private void run_local() | |||
{ | |||
var train_hooks = new Action[0]; | |||
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + | |||
"after every checkpoint. Checkpoint frequency is determined " + | |||
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + | |||
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); | |||
var evaluator = new _Evaluator<Thyp>(_estimator, _eval_spec, _train_spec.max_steps); | |||
var saving_listeners = new _NewCheckpointListenerForEvaluate<Thyp>[0]; | |||
_estimator.train(input_fn: _train_spec.input_fn, | |||
max_steps: _train_spec.max_steps, | |||
hooks: train_hooks, | |||
saving_listeners: saving_listeners); | |||
} | |||
} | |||
} |
@@ -186,11 +186,6 @@ namespace Tensorflow | |||
if (op_def == null) | |||
op_def = g.GetOpDef(node_def.Op); | |||
if (node_def.Name.Equals("learn_rate/cond/pred_id")) | |||
{ | |||
} | |||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||
_is_stateful = op_def.IsStateful; | |||
@@ -42,8 +42,14 @@ https://tensorflownet.readthedocs.io</Description> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<Compile Remove="Distribute\**" /> | |||
<Compile Remove="Models\**" /> | |||
<Compile Remove="runtimes\**" /> | |||
<EmbeddedResource Remove="Distribute\**" /> | |||
<EmbeddedResource Remove="Models\**" /> | |||
<EmbeddedResource Remove="runtimes\**" /> | |||
<None Remove="Distribute\**" /> | |||
<None Remove="Models\**" /> | |||
<None Remove="runtimes\**" /> | |||
<None Include="..\..\LICENSE"> | |||
<Pack>True</Pack> | |||
@@ -61,8 +67,7 @@ https://tensorflownet.readthedocs.io</Description> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<Folder Include="Distribute\" /> | |||
<Folder Include="Estimators\" /> | |||
<Folder Include="Keras\Initializers\" /> | |||
<Folder Include="Models\" /> | |||
</ItemGroup> | |||
</Project> |
@@ -1,64 +0,0 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Estimators; | |||
namespace TensorFlowNET.UnitTest.Estimators | |||
{ | |||
/// <summary> | |||
/// estimator/tensorflow_estimator/python/estimator/run_config_test.py | |||
/// </summary> | |||
[TestClass] | |||
public class RunConfigTest | |||
{ | |||
private static readonly string _TEST_DIR = "test_dir"; | |||
private static readonly string _MASTER = "master_"; | |||
private static readonly string _NOT_SUPPORTED_REPLACE_PROPERTY_MSG = "Replacing .*is not supported"; | |||
private static readonly string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; | |||
private static readonly string _MODEL_DIR_ERR = "model_dir should be non-empty"; | |||
private static readonly string _MODEL_DIR_TF_CONFIG_ERR = "model_dir in TF_CONFIG should be non-empty"; | |||
private static readonly string _MODEL_DIR_MISMATCH_ERR = "`model_dir` provided in RunConfig construct, if set, must have the same value as the model_dir in TF_CONFIG. "; | |||
private static readonly string _SAVE_SUMMARY_STEPS_ERR = "save_summary_steps should be >= 0"; | |||
private static readonly string _SAVE_CKPT_STEPS_ERR = "save_checkpoints_steps should be >= 0"; | |||
private static readonly string _SAVE_CKPT_SECS_ERR = "save_checkpoints_secs should be >= 0"; | |||
private static readonly string _SESSION_CONFIG_ERR = "session_config must be instance of ConfigProto"; | |||
private static readonly string _KEEP_CKPT_MAX_ERR = "keep_checkpoint_max should be >= 0"; | |||
private static readonly string _KEEP_CKPT_HOURS_ERR = "keep_checkpoint_every_n_hours should be > 0"; | |||
private static readonly string _TF_RANDOM_SEED_ERR = "tf_random_seed must be integer"; | |||
private static readonly string _DEVICE_FN_ERR = "device_fn must be callable with exactly one argument \"op\"."; | |||
private static readonly string _ONE_CHIEF_ERR = "The \"cluster\" in TF_CONFIG must have only one \"chief\" node."; | |||
private static readonly string _ONE_MASTER_ERR = "The \"cluster\" in TF_CONFIG must have only one \"master\" node."; | |||
private static readonly string _MISSING_CHIEF_ERR = "If \"cluster\" is set .* it must have one \"chief\" node"; | |||
private static readonly string _MISSING_TASK_TYPE_ERR = "If \"cluster\" is set .* task type must be set"; | |||
private static readonly string _MISSING_TASK_ID_ERR = "If \"cluster\" is set .* task index must be set"; | |||
private static readonly string _INVALID_TASK_INDEX_ERR = "is not a valid task_id"; | |||
private static readonly string _NEGATIVE_TASK_INDEX_ERR = "Task index must be non-negative number."; | |||
private static readonly string _INVALID_TASK_TYPE_ERR = "is not a valid task_type"; | |||
private static readonly string _INVALID_TASK_TYPE_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task type must be WORKER."; | |||
private static readonly string _INVALID_TASK_INDEX_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task index must be 0."; | |||
private static readonly string _INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, task_type `evaluator` is not supported."; | |||
private static readonly string _INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, job `chief` is not supported."; | |||
private static readonly string _INVALID_SERVICE_TYPE_ERR = "If \"service\" is set in TF_CONFIG, it must be a dict. Given"; | |||
private static readonly string _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR = "experimental_max_worker_delay_secs must be an integer if set."; | |||
private static readonly string _SESSION_CREATION_TIMEOUT_SECS_ERR = "session_creation_timeout_secs should be > 0"; | |||
[TestMethod] | |||
public void test_default_property_values() | |||
{ | |||
var config = new RunConfig(); | |||
Assert.IsNull(config.model_dir); | |||
Assert.IsNull(config.session_config); | |||
Assert.IsNull(config.tf_random_seed); | |||
Assert.AreEqual(100, config.save_summary_steps); | |||
Assert.AreEqual(600, config.save_checkpoints_secs); | |||
Assert.AreEqual(5, config.keep_checkpoint_max); | |||
Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours); | |||
Assert.IsNull(config.service); | |||
Assert.IsNull(config.device_fn); | |||
Assert.IsNull(config.experimental_max_worker_delay_secs); | |||
Assert.AreEqual(7200, config.session_creation_timeout_secs); | |||
} | |||
} | |||
} |