@@ -1,61 +0,0 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Estimators; | |||||
using Tensorflow.Data; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow | |||||
{ | |||||
public Estimator_Internal estimator { get; } = new Estimator_Internal(); | |||||
public class Estimator_Internal | |||||
{ | |||||
public Experimental experimental { get; } = new Experimental(); | |||||
public Estimator<Thyp> Estimator<Thyp>(Func<IEstimatorInputs, EstimatorSpec> model_fn, | |||||
string model_dir = null, | |||||
RunConfig config = null, | |||||
Thyp hyperParams = default) | |||||
=> new Estimator<Thyp>(model_fn: model_fn, model_dir: model_dir, config: config, hyperParams: hyperParams); | |||||
public RunConfig RunConfig(string model_dir = null, int save_checkpoints_secs = 180) | |||||
=> new RunConfig(model_dir: model_dir, save_checkpoints_secs: save_checkpoints_secs); | |||||
public void train_and_evaluate<Thyp>(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||||
=> Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||||
public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||||
=> new TrainSpec(input_fn: input_fn, max_steps: max_steps); | |||||
/// <summary> | |||||
/// Create an `Exporter` to use with `tf.estimator.EvalSpec`. | |||||
/// </summary> | |||||
/// <param name="name"></param> | |||||
/// <param name="serving_input_receiver_fn"></param> | |||||
/// <param name="as_text"></param> | |||||
/// <returns></returns> | |||||
public FinalExporter FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) | |||||
=> new FinalExporter(name: name, serving_input_receiver_fn: serving_input_receiver_fn, | |||||
as_text: as_text); | |||||
public EvalSpec EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||||
=> new EvalSpec(name: name, input_fn: input_fn, exporters: exporters); | |||||
} | |||||
} | |||||
} |
@@ -1,147 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.IO; | |||||
using System.Text; | |||||
using Tensorflow.Data; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
/// <summary> | |||||
/// Estimator class to train and evaluate TensorFlow models. | |||||
/// </summary> | |||||
public class Estimator<Thyp> : IObjectLife | |||||
{ | |||||
RunConfig _config; | |||||
public RunConfig config => _config; | |||||
ConfigProto _session_config; | |||||
public ConfigProto session_config => _session_config; | |||||
Func<IEstimatorInputs, EstimatorSpec> _model_fn; | |||||
Thyp _hyperParams; | |||||
public Estimator(Func<IEstimatorInputs, EstimatorSpec> model_fn, | |||||
string model_dir, | |||||
RunConfig config, | |||||
Thyp hyperParams) | |||||
{ | |||||
_config = config; | |||||
_config.model_dir = config.model_dir ?? model_dir; | |||||
_session_config = config.session_config; | |||||
_model_fn = model_fn; | |||||
_hyperParams = hyperParams; | |||||
} | |||||
public Estimator<Thyp> train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null, | |||||
_NewCheckpointListenerForEvaluate<Thyp>[] saving_listeners = null) | |||||
{ | |||||
if(max_steps > 0) | |||||
{ | |||||
var start_step = _load_global_step_from_checkpoint_dir(_config.model_dir); | |||||
if (max_steps <= start_step) | |||||
{ | |||||
Console.WriteLine("Skipping training since max_steps has already saved."); | |||||
return this; | |||||
} | |||||
} | |||||
var loss = _train_model(input_fn); | |||||
print($"Loss for final step: {loss}."); | |||||
return this; | |||||
} | |||||
private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||||
{ | |||||
// var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||||
// should use NewCheckpointReader (not implemented) | |||||
var cp = tf.train.get_checkpoint_state(checkpoint_dir); | |||||
return cp.AllModelCheckpointPaths.Count - 1; | |||||
} | |||||
private Tensor _train_model(Func<DatasetV1Adapter> input_fn) | |||||
{ | |||||
return _train_model_default(input_fn); | |||||
} | |||||
private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn) | |||||
{ | |||||
using (var g = tf.Graph().as_default()) | |||||
{ | |||||
var global_step_tensor = _create_and_assert_global_step(g); | |||||
// Skip creating a read variable if _create_and_assert_global_step | |||||
// returns None (e.g. tf.contrib.estimator.SavedModelEstimator). | |||||
if (global_step_tensor != null) | |||||
TrainingUtil._get_or_create_global_step_read(g); | |||||
var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
{ | |||||
var result = _call_input_fn(input_fn, mode); | |||||
return EstimatorUtil.parse_input_fn_result(result); | |||||
} | |||||
/// <summary> | |||||
/// Calls the input function. | |||||
/// </summary> | |||||
/// <param name="input_fn"></param> | |||||
/// <param name="mode"></param> | |||||
private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
{ | |||||
return input_fn(); | |||||
} | |||||
private Tensor _create_and_assert_global_step(Graph graph) | |||||
{ | |||||
var step = _create_global_step(graph); | |||||
Debug.Assert(step == tf.train.get_global_step(graph)); | |||||
Debug.Assert(step.dtype.is_integer()); | |||||
return step; | |||||
} | |||||
private RefVariable _create_global_step(Graph graph) | |||||
{ | |||||
return tf.train.create_global_step(graph); | |||||
} | |||||
public string eval_dir(string name = null) | |||||
{ | |||||
return Path.Combine(config.model_dir, string.IsNullOrEmpty(name) ? "eval" : $"eval_" + name); | |||||
} | |||||
public void __init__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void __enter__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void __del__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void __exit__() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -1,16 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class EstimatorSpec | |||||
{ | |||||
public EstimatorSpec(Operation train_op) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -1,15 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Data; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class EstimatorUtil | |||||
{ | |||||
public static (Dictionary<string, Tensor>, Dictionary<string, Tensor>) parse_input_fn_result(DatasetV1Adapter result) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -1,16 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class EvalSpec | |||||
{ | |||||
string _name; | |||||
public EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||||
{ | |||||
_name = name; | |||||
} | |||||
} | |||||
} |
@@ -1,61 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class Experimental | |||||
{ | |||||
/// <summary> | |||||
/// Creates hook to stop if metric does not increase within given max steps. | |||||
/// </summary> | |||||
/// <typeparam name="Thyp">type of hyper parameters</typeparam> | |||||
/// <param name="estimator"></param> | |||||
/// <param name="metric_name"></param> | |||||
/// <param name="max_steps_without_increase"></param> | |||||
/// <param name="eval_dir"></param> | |||||
/// <param name="min_steps"></param> | |||||
/// <param name="run_every_secs"></param> | |||||
/// <param name="run_every_steps"></param> | |||||
/// <returns></returns> | |||||
public object stop_if_no_increase_hook<Thyp>(Estimator<Thyp> estimator, | |||||
string metric_name, | |||||
int max_steps_without_increase, | |||||
string eval_dir = null, | |||||
int min_steps = 0, | |||||
int run_every_secs = 60, | |||||
int run_every_steps = 0) | |||||
=> _stop_if_no_metric_improvement_hook(estimator: estimator, | |||||
metric_name: metric_name, | |||||
max_steps_without_increase: max_steps_without_increase, | |||||
eval_dir: eval_dir, | |||||
min_steps: min_steps, | |||||
run_every_secs: run_every_secs, | |||||
run_every_steps: run_every_steps); | |||||
private object _stop_if_no_metric_improvement_hook<Thyp>(Estimator<Thyp> estimator, | |||||
string metric_name, | |||||
int max_steps_without_increase, | |||||
string eval_dir = null, | |||||
int min_steps = 0, | |||||
int run_every_secs = 60, | |||||
int run_every_steps = 0) | |||||
{ | |||||
eval_dir = eval_dir ?? estimator.eval_dir(); | |||||
// var is_lhs_better = higher_is_better ? operator.gt: operator.lt; | |||||
Func<bool> stop_if_no_metric_improvement_fn = () => | |||||
{ | |||||
return false; | |||||
}; | |||||
return make_early_stopping_hook(); | |||||
} | |||||
public object make_early_stopping_hook() | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -1,11 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public abstract class Exporter<Thyp> | |||||
{ | |||||
public abstract void export(Estimator<Thyp> estimator, string export_path, string checkpoint_path); | |||||
} | |||||
} |
@@ -1,14 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class FinalExporter | |||||
{ | |||||
public FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -1,89 +0,0 @@ | |||||
using System.IO; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class HyperParams | |||||
{ | |||||
/// <summary> | |||||
/// root dir | |||||
/// </summary> | |||||
public string data_root_dir { get; set; } | |||||
/// <summary> | |||||
/// results dir | |||||
/// </summary> | |||||
public string result_dir { get; set; } = "results"; | |||||
/// <summary> | |||||
/// model dir | |||||
/// </summary> | |||||
public string model_dir { get; set; } = "model"; | |||||
public string eval_dir { get; set; } = "eval"; | |||||
public string test_dir { get; set; } = "test"; | |||||
public int dim { get; set; } = 300; | |||||
public float dropout { get; set; } = 0.5f; | |||||
public int num_oov_buckets { get; set; } = 1; | |||||
public int epochs { get; set; } = 25; | |||||
public int epoch_no_imprv { get; set; } = 3; | |||||
public int batch_size { get; set; } = 20; | |||||
public int buffer { get; set; } = 15000; | |||||
public int lstm_size { get; set; } = 100; | |||||
public string lr_method { get; set; } = "adam"; | |||||
public float lr { get; set; } = 0.001f; | |||||
public float lr_decay { get; set; } = 0.9f; | |||||
/// <summary> | |||||
/// lstm on chars | |||||
/// </summary> | |||||
public int hidden_size_char { get; set; } = 100; | |||||
/// <summary> | |||||
/// lstm on word embeddings | |||||
/// </summary> | |||||
public int hidden_size_lstm { get; set; } = 300; | |||||
/// <summary> | |||||
/// is clipping | |||||
/// </summary> | |||||
public bool clip { get; set; } = false; | |||||
public string filepath_dev { get; set; } | |||||
public string filepath_test { get; set; } | |||||
public string filepath_train { get; set; } | |||||
public string filepath_words { get; set; } | |||||
public string filepath_chars { get; set; } | |||||
public string filepath_tags { get; set; } | |||||
public string filepath_glove { get; set; } | |||||
public HyperParams(string dataDir) | |||||
{ | |||||
data_root_dir = dataDir; | |||||
if (string.IsNullOrEmpty(data_root_dir)) | |||||
throw new ValueError("Please specifiy the root data directory"); | |||||
if (!Directory.Exists(data_root_dir)) | |||||
Directory.CreateDirectory(data_root_dir); | |||||
result_dir = Path.Combine(data_root_dir, result_dir); | |||||
if (!Directory.Exists(result_dir)) | |||||
Directory.CreateDirectory(result_dir); | |||||
model_dir = Path.Combine(result_dir, model_dir); | |||||
if (!Directory.Exists(model_dir)) | |||||
Directory.CreateDirectory(model_dir); | |||||
test_dir = Path.Combine(result_dir, test_dir); | |||||
if (!Directory.Exists(test_dir)) | |||||
Directory.CreateDirectory(test_dir); | |||||
eval_dir = Path.Combine(result_dir, eval_dir); | |||||
if (!Directory.Exists(eval_dir)) | |||||
Directory.CreateDirectory(eval_dir); | |||||
} | |||||
} | |||||
} |
@@ -1,12 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public interface IEstimatorInputs | |||||
{ | |||||
} | |||||
} |
@@ -1,7 +0,0 @@ | |||||
# TensorFlow Estimator | |||||
TensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model. | |||||
https://github.com/tensorflow/estimator |
@@ -1,103 +0,0 @@ | |||||
using System; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class RunConfig | |||||
{ | |||||
// A list of the property names in RunConfig that the user is allowed to change. | |||||
private static readonly string[] _DEFAULT_REPLACEABLE_LIST = new [] | |||||
{ | |||||
"model_dir", | |||||
"tf_random_seed", | |||||
"save_summary_steps", | |||||
"save_checkpoints_steps", | |||||
"save_checkpoints_secs", | |||||
"session_config", | |||||
"keep_checkpoint_max", | |||||
"keep_checkpoint_every_n_hours", | |||||
"log_step_count_steps", | |||||
"train_distribute", | |||||
"device_fn", | |||||
"protocol", | |||||
"eval_distribute", | |||||
"experimental_distribute", | |||||
"experimental_max_worker_delay_secs", | |||||
"session_creation_timeout_secs" | |||||
}; | |||||
#region const values | |||||
private const string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; | |||||
private const string _TF_CONFIG_ENV = "TF_CONFIG"; | |||||
private const string _TASK_ENV_KEY = "task"; | |||||
private const string _TASK_TYPE_KEY = "type"; | |||||
private const string _TASK_ID_KEY = "index"; | |||||
private const string _CLUSTER_KEY = "cluster"; | |||||
private const string _SERVICE_KEY = "service"; | |||||
private const string _SESSION_MASTER_KEY = "session_master"; | |||||
private const string _EVAL_SESSION_MASTER_KEY = "eval_session_master"; | |||||
private const string _MODEL_DIR_KEY = "model_dir"; | |||||
private const string _LOCAL_MASTER = ""; | |||||
private const string _GRPC_SCHEME = "grpc://"; | |||||
#endregion | |||||
public string model_dir { get; set; } | |||||
public ConfigProto session_config { get; set; } | |||||
public int? tf_random_seed { get; set; } | |||||
public int save_summary_steps { get; set; } = 100; | |||||
public int save_checkpoints_steps { get; set; } | |||||
public int save_checkpoints_secs { get; set; } = 600; | |||||
public int keep_checkpoint_max { get; set; } = 5; | |||||
public int keep_checkpoint_every_n_hours { get; set; } = 10000; | |||||
public int log_step_count_steps{ get; set; } = 100; | |||||
public object train_distribute { get; set; } | |||||
public object device_fn { get; set; } | |||||
public object protocol { get; set; } | |||||
public object eval_distribute { get; set; } | |||||
public object experimental_distribute { get; set; } | |||||
public object experimental_max_worker_delay_secs { get; set; } | |||||
public int session_creation_timeout_secs { get; set; } = 7200; | |||||
public object service { get; set; } | |||||
public RunConfig() | |||||
{ | |||||
Initialize(); | |||||
} | |||||
public RunConfig(string model_dir, | |||||
int save_checkpoints_secs) | |||||
{ | |||||
this.model_dir = model_dir; | |||||
this.save_checkpoints_secs = save_checkpoints_secs; | |||||
Initialize(); | |||||
} | |||||
public RunConfig( | |||||
string model_dir = null, | |||||
int? tf_random_seed = null, | |||||
int save_summary_steps=100, | |||||
object save_checkpoints_steps = null, // _USE_DEFAULT | |||||
object save_checkpoints_secs = null, // _USE_DEFAULT | |||||
object session_config = null, | |||||
int keep_checkpoint_max = 5, | |||||
int keep_checkpoint_every_n_hours = 10000, | |||||
int log_step_count_steps = 100, | |||||
object train_distribute = null, | |||||
object device_fn = null, | |||||
object protocol = null, | |||||
object eval_distribute = null, | |||||
object experimental_distribute = null, | |||||
object experimental_max_worker_delay_secs = null, | |||||
int session_creation_timeout_secs = 7200) | |||||
{ | |||||
this.model_dir = model_dir; | |||||
Initialize(); | |||||
} | |||||
private void Initialize() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -1,22 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Data; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class TrainSpec | |||||
{ | |||||
int _max_steps; | |||||
public int max_steps => _max_steps; | |||||
Func<DatasetV1Adapter> _input_fn; | |||||
public Func<DatasetV1Adapter> input_fn => _input_fn; | |||||
public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||||
{ | |||||
_max_steps = max_steps; | |||||
_input_fn = input_fn; | |||||
} | |||||
} | |||||
} |
@@ -1,17 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class Training | |||||
{ | |||||
public static void train_and_evaluate<Thyp>(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||||
{ | |||||
var executor = new _TrainingExecutor<Thyp>(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||||
var config = estimator.config; | |||||
executor.run(); | |||||
} | |||||
} | |||||
} |
@@ -1,14 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class _Evaluator<Thyp> | |||||
{ | |||||
public _Evaluator(Estimator<Thyp> estimator, EvalSpec eval_spec, int max_training_steps) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -1,16 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class _NewCheckpointListenerForEvaluate<Thyp> | |||||
{ | |||||
_Evaluator<Thyp> _evaluator; | |||||
public _NewCheckpointListenerForEvaluate(_Evaluator<Thyp> evaluator, int eval_throttle_secs) | |||||
{ | |||||
_evaluator = evaluator; | |||||
} | |||||
} | |||||
} |
@@ -1,14 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
public class _SavedModelExporter<Thyp> : Exporter<Thyp> | |||||
{ | |||||
public override void export(Estimator<Thyp> estimator, string export_path, string checkpoint_path) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -1,48 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Estimators | |||||
{ | |||||
/// <summary> | |||||
/// The executor to run `Estimator` training and evaluation. | |||||
/// </summary> | |||||
internal class _TrainingExecutor<Thyp> | |||||
{ | |||||
Estimator<Thyp> _estimator; | |||||
EvalSpec _eval_spec; | |||||
TrainSpec _train_spec; | |||||
public _TrainingExecutor(Estimator<Thyp> estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||||
{ | |||||
_estimator = estimator; | |||||
_train_spec = train_spec; | |||||
_eval_spec = eval_spec; | |||||
} | |||||
public void run() | |||||
{ | |||||
var config = _estimator.config; | |||||
Console.WriteLine("Running training and evaluation locally (non-distributed)."); | |||||
run_local(); | |||||
} | |||||
/// <summary> | |||||
/// Runs training and evaluation locally (non-distributed). | |||||
/// </summary> | |||||
private void run_local() | |||||
{ | |||||
var train_hooks = new Action[0]; | |||||
Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + | |||||
"after every checkpoint. Checkpoint frequency is determined " + | |||||
$"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + | |||||
$"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); | |||||
var evaluator = new _Evaluator<Thyp>(_estimator, _eval_spec, _train_spec.max_steps); | |||||
var saving_listeners = new _NewCheckpointListenerForEvaluate<Thyp>[0]; | |||||
_estimator.train(input_fn: _train_spec.input_fn, | |||||
max_steps: _train_spec.max_steps, | |||||
hooks: train_hooks, | |||||
saving_listeners: saving_listeners); | |||||
} | |||||
} | |||||
} |
@@ -186,11 +186,6 @@ namespace Tensorflow | |||||
if (op_def == null) | if (op_def == null) | ||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
if (node_def.Name.Equals("learn_rate/cond/pred_id")) | |||||
{ | |||||
} | |||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | ||||
_is_stateful = op_def.IsStateful; | _is_stateful = op_def.IsStateful; | ||||
@@ -42,8 +42,14 @@ https://tensorflownet.readthedocs.io</Description> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<Compile Remove="Distribute\**" /> | |||||
<Compile Remove="Models\**" /> | |||||
<Compile Remove="runtimes\**" /> | <Compile Remove="runtimes\**" /> | ||||
<EmbeddedResource Remove="Distribute\**" /> | |||||
<EmbeddedResource Remove="Models\**" /> | |||||
<EmbeddedResource Remove="runtimes\**" /> | <EmbeddedResource Remove="runtimes\**" /> | ||||
<None Remove="Distribute\**" /> | |||||
<None Remove="Models\**" /> | |||||
<None Remove="runtimes\**" /> | <None Remove="runtimes\**" /> | ||||
<None Include="..\..\LICENSE"> | <None Include="..\..\LICENSE"> | ||||
<Pack>True</Pack> | <Pack>True</Pack> | ||||
@@ -61,8 +67,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<Folder Include="Distribute\" /> | |||||
<Folder Include="Estimators\" /> | |||||
<Folder Include="Keras\Initializers\" /> | <Folder Include="Keras\Initializers\" /> | ||||
<Folder Include="Models\" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
</Project> | </Project> |
@@ -1,64 +0,0 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using Tensorflow; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Estimators; | |||||
namespace TensorFlowNET.UnitTest.Estimators | |||||
{ | |||||
/// <summary> | |||||
/// estimator/tensorflow_estimator/python/estimator/run_config_test.py | |||||
/// </summary> | |||||
[TestClass] | |||||
public class RunConfigTest | |||||
{ | |||||
private static readonly string _TEST_DIR = "test_dir"; | |||||
private static readonly string _MASTER = "master_"; | |||||
private static readonly string _NOT_SUPPORTED_REPLACE_PROPERTY_MSG = "Replacing .*is not supported"; | |||||
private static readonly string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; | |||||
private static readonly string _MODEL_DIR_ERR = "model_dir should be non-empty"; | |||||
private static readonly string _MODEL_DIR_TF_CONFIG_ERR = "model_dir in TF_CONFIG should be non-empty"; | |||||
private static readonly string _MODEL_DIR_MISMATCH_ERR = "`model_dir` provided in RunConfig construct, if set, must have the same value as the model_dir in TF_CONFIG. "; | |||||
private static readonly string _SAVE_SUMMARY_STEPS_ERR = "save_summary_steps should be >= 0"; | |||||
private static readonly string _SAVE_CKPT_STEPS_ERR = "save_checkpoints_steps should be >= 0"; | |||||
private static readonly string _SAVE_CKPT_SECS_ERR = "save_checkpoints_secs should be >= 0"; | |||||
private static readonly string _SESSION_CONFIG_ERR = "session_config must be instance of ConfigProto"; | |||||
private static readonly string _KEEP_CKPT_MAX_ERR = "keep_checkpoint_max should be >= 0"; | |||||
private static readonly string _KEEP_CKPT_HOURS_ERR = "keep_checkpoint_every_n_hours should be > 0"; | |||||
private static readonly string _TF_RANDOM_SEED_ERR = "tf_random_seed must be integer"; | |||||
private static readonly string _DEVICE_FN_ERR = "device_fn must be callable with exactly one argument \"op\"."; | |||||
private static readonly string _ONE_CHIEF_ERR = "The \"cluster\" in TF_CONFIG must have only one \"chief\" node."; | |||||
private static readonly string _ONE_MASTER_ERR = "The \"cluster\" in TF_CONFIG must have only one \"master\" node."; | |||||
private static readonly string _MISSING_CHIEF_ERR = "If \"cluster\" is set .* it must have one \"chief\" node"; | |||||
private static readonly string _MISSING_TASK_TYPE_ERR = "If \"cluster\" is set .* task type must be set"; | |||||
private static readonly string _MISSING_TASK_ID_ERR = "If \"cluster\" is set .* task index must be set"; | |||||
private static readonly string _INVALID_TASK_INDEX_ERR = "is not a valid task_id"; | |||||
private static readonly string _NEGATIVE_TASK_INDEX_ERR = "Task index must be non-negative number."; | |||||
private static readonly string _INVALID_TASK_TYPE_ERR = "is not a valid task_type"; | |||||
private static readonly string _INVALID_TASK_TYPE_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task type must be WORKER."; | |||||
private static readonly string _INVALID_TASK_INDEX_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task index must be 0."; | |||||
private static readonly string _INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, task_type `evaluator` is not supported."; | |||||
private static readonly string _INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, job `chief` is not supported."; | |||||
private static readonly string _INVALID_SERVICE_TYPE_ERR = "If \"service\" is set in TF_CONFIG, it must be a dict. Given"; | |||||
private static readonly string _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR = "experimental_max_worker_delay_secs must be an integer if set."; | |||||
private static readonly string _SESSION_CREATION_TIMEOUT_SECS_ERR = "session_creation_timeout_secs should be > 0"; | |||||
[TestMethod] | |||||
public void test_default_property_values() | |||||
{ | |||||
var config = new RunConfig(); | |||||
Assert.IsNull(config.model_dir); | |||||
Assert.IsNull(config.session_config); | |||||
Assert.IsNull(config.tf_random_seed); | |||||
Assert.AreEqual(100, config.save_summary_steps); | |||||
Assert.AreEqual(600, config.save_checkpoints_secs); | |||||
Assert.AreEqual(5, config.keep_checkpoint_max); | |||||
Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours); | |||||
Assert.IsNull(config.service); | |||||
Assert.IsNull(config.device_fn); | |||||
Assert.IsNull(config.experimental_max_worker_delay_secs); | |||||
Assert.AreEqual(7200, config.session_creation_timeout_secs); | |||||
} | |||||
} | |||||
} |