From eb99b2a4a7ff7362da4961a5b8ac69ce7f048081 Mon Sep 17 00:00:00 2001 From: Kerry Jiang Date: Wed, 11 Sep 2019 17:34:54 -0700 Subject: [PATCH] tried to implement RunConfig and RunConfigTest --- .../Estimators/RunConfig.cs | 55 +++++++++++++++- .../Estimators/RunConfigTest.cs | 65 +++++++++++++++++++ 2 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs diff --git a/src/TensorFlowNET.Core/Estimators/RunConfig.cs b/src/TensorFlowNET.Core/Estimators/RunConfig.cs index f1c83e87..e13d25aa 100644 --- a/src/TensorFlowNET.Core/Estimators/RunConfig.cs +++ b/src/TensorFlowNET.Core/Estimators/RunConfig.cs @@ -44,10 +44,9 @@ namespace Tensorflow.Estimators #endregion private static readonly object _USE_DEFAULT = new object(); - public string model_dir { get; set; } public ConfigProto session_config { get; set; } - public int tf_random_seed { get; set; } + public int? tf_random_seed { get; set; } public int save_summary_steps { get; set; } = 100; public object save_checkpoints_steps { get; set; } = _USE_DEFAULT; public object save_checkpoints_secs { get; set; } = _USE_DEFAULT; @@ -61,10 +60,60 @@ namespace Tensorflow.Estimators public object experimental_distribute { get; set; } public object experimental_max_worker_delay_secs { get; set; } public int session_creation_timeout_secs { get; set; } = 7200; + public object service { get; set; } + + public RunConfig() + { + Initialize(); + } public RunConfig(string model_dir) { - this.model_dir = model_dir; + this.model_dir = model_dir; + Initialize(); + } + + public RunConfig( + string model_dir = null, + int? tf_random_seed = null, + int save_summary_steps=100, + object save_checkpoints_steps = null, // _USE_DEFAULT + object save_checkpoints_secs = null, // _USE_DEFAULT + object session_config = null, + int keep_checkpoint_max = 5, + int keep_checkpoint_every_n_hours = 10000, + int log_step_count_steps = 100, + object train_distribute = null, + object device_fn = null, + object protocol = null, + object eval_distribute = null, + object experimental_distribute = null, + object experimental_max_worker_delay_secs = null, + int session_creation_timeout_secs = 7200) + { + this.model_dir = model_dir; + Initialize(); + } + + private void Initialize() + { + if (this.save_checkpoints_steps == _USE_DEFAULT && this.save_checkpoints_secs == _USE_DEFAULT) + { + this.save_checkpoints_steps = null; + this.save_checkpoints_secs = 600; + } + else if (this.save_checkpoints_secs == _USE_DEFAULT) + { + this.save_checkpoints_secs = null; + } + else if (this.save_checkpoints_steps == _USE_DEFAULT) + { + this.save_checkpoints_steps = null; + } + else if (this.save_checkpoints_steps != null && save_checkpoints_secs != null) + { + throw new Exception(_SAVE_CKPT_ERR); + } } } } diff --git a/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs b/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs new file mode 100644 index 00000000..6a4b1806 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs @@ -0,0 +1,65 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Eager; +using Tensorflow.Estimators; + +namespace TensorFlowNET.UnitTest.Estimators +{ + /// + /// estimator/tensorflow_estimator/python/estimator/run_config_test.py + /// + [TestClass] + public class RunConfigTest + { + private static readonly string _TEST_DIR = "test_dir"; + private static readonly string _MASTER = "master_"; + private static readonly string _NOT_SUPPORTED_REPLACE_PROPERTY_MSG = "Replacing .*is not supported"; + private static readonly string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; + private static readonly string _MODEL_DIR_ERR = "model_dir should be non-empty"; + private static readonly string _MODEL_DIR_TF_CONFIG_ERR = "model_dir in TF_CONFIG should be non-empty"; + private static readonly string _MODEL_DIR_MISMATCH_ERR = "`model_dir` provided in RunConfig construct, if set, must have the same value as the model_dir in TF_CONFIG. "; + private static readonly string _SAVE_SUMMARY_STEPS_ERR = "save_summary_steps should be >= 0"; + private static readonly string _SAVE_CKPT_STEPS_ERR = "save_checkpoints_steps should be >= 0"; + private static readonly string _SAVE_CKPT_SECS_ERR = "save_checkpoints_secs should be >= 0"; + private static readonly string _SESSION_CONFIG_ERR = "session_config must be instance of ConfigProto"; + private static readonly string _KEEP_CKPT_MAX_ERR = "keep_checkpoint_max should be >= 0"; + private static readonly string _KEEP_CKPT_HOURS_ERR = "keep_checkpoint_every_n_hours should be > 0"; + private static readonly string _TF_RANDOM_SEED_ERR = "tf_random_seed must be integer"; + private static readonly string _DEVICE_FN_ERR = "device_fn must be callable with exactly one argument \"op\"."; + private static readonly string _ONE_CHIEF_ERR = "The \"cluster\" in TF_CONFIG must have only one \"chief\" node."; + private static readonly string _ONE_MASTER_ERR = "The \"cluster\" in TF_CONFIG must have only one \"master\" node."; + private static readonly string _MISSING_CHIEF_ERR = "If \"cluster\" is set .* it must have one \"chief\" node"; + private static readonly string _MISSING_TASK_TYPE_ERR = "If \"cluster\" is set .* task type must be set"; + private static readonly string _MISSING_TASK_ID_ERR = "If \"cluster\" is set .* task index must be set"; + private static readonly string _INVALID_TASK_INDEX_ERR = "is not a valid task_id"; + private static readonly string _NEGATIVE_TASK_INDEX_ERR = "Task index must be non-negative number."; + private static readonly string _INVALID_TASK_TYPE_ERR = "is not a valid task_type"; + private static readonly string _INVALID_TASK_TYPE_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task type must be WORKER."; + private static readonly string _INVALID_TASK_INDEX_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task index must be 0."; + private static readonly string _INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, task_type `evaluator` is not supported."; + private static readonly string _INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, job `chief` is not supported."; + private static readonly string _INVALID_SERVICE_TYPE_ERR = "If \"service\" is set in TF_CONFIG, it must be a dict. Given"; + private static readonly string _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR = "experimental_max_worker_delay_secs must be an integer if set."; + private static readonly string _SESSION_CREATION_TIMEOUT_SECS_ERR = "session_creation_timeout_secs should be > 0"; + + [TestMethod] + public void test_default_property_values() + { + var config = new RunConfig(); + + Assert.IsNull(config.model_dir); + Assert.IsNull(config.session_config); + Assert.IsNull(config.tf_random_seed); + Assert.AreEqual(100, config.save_summary_steps); + Assert.AreEqual(600, config.save_checkpoints_secs); + Assert.IsNull(config.save_checkpoints_steps); + Assert.AreEqual(5, config.keep_checkpoint_max); + Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours); + Assert.IsNull(config.service); + Assert.IsNull(config.device_fn); + Assert.IsNull(config.experimental_max_worker_delay_secs); + Assert.AreEqual(7200, config.session_creation_timeout_secs); + } + } +}