From de3ecc8556b2b1c5acb51532fece6f20a4f46a6a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 12 Sep 2019 06:31:17 -0500 Subject: [PATCH] create_estimator_and_inputs --- src/TensorFlowNET.Core/APIs/tf.estimator.cs | 4 ++-- src/TensorFlowNET.Core/Estimators/Estimator.cs | 2 +- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 6 +++--- src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs | 8 +++++--- test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs | 1 - test/TensorFlowNET.UnitTest/MultithreadingTests.cs | 3 +++ 6 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.estimator.cs b/src/TensorFlowNET.Core/APIs/tf.estimator.cs index 8314f202..9789e11f 100644 --- a/src/TensorFlowNET.Core/APIs/tf.estimator.cs +++ b/src/TensorFlowNET.Core/APIs/tf.estimator.cs @@ -26,8 +26,8 @@ namespace Tensorflow public class Estimator_Internal { - public Estimator Estimator(RunConfig config) - => new Estimator(config: config); + public Estimator Estimator(Action model_fn, RunConfig config) + => new Estimator(model_fn: model_fn, config: config); public RunConfig RunConfig(string model_dir) => new RunConfig(model_dir: model_dir); diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs index 56186069..91596d7d 100644 --- a/src/TensorFlowNET.Core/Estimators/Estimator.cs +++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Estimators string _model_dir; - public Estimator(RunConfig config) + public Estimator(Action model_fn, RunConfig config) { _config = config; _model_dir = _config.model_dir; diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index d2d5d1b4..3594649d 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.3 + 0.11.4 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.3.0 + 0.11.4.0 Changes since v0.10.0: 1. Upgrade NumSharp to v0.20. 2. Add DisposableObject class to manage object lifetime. @@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io 9. MultiThread is safe. 10. Support n-dim indexing for tensor. 7.3 - 0.11.3.0 + 0.11.4.0 LICENSE true true diff --git a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs index ee70abc3..33c5c236 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs @@ -18,18 +18,20 @@ namespace Tensorflow.Models.ObjectDetection int sample_1_of_n_eval_examples = 0, int sample_1_of_n_eval_on_train_examples = 1) { - var estimator = tf.estimator.Estimator(config: run_config); - var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); var eval_input_configs = config.EvalInputReader; var eval_input_fns = new Action[eval_input_configs.Count]; + var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray(); + Action model_fn = () => { }; + var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config); return new TrainAndEvalDict { estimator = estimator, train_steps = train_steps, - eval_input_fns = eval_input_fns + eval_input_fns = eval_input_fns, + eval_input_names = eval_input_names }; } diff --git a/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs b/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs index 6a4b1806..66cb48e3 100644 --- a/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs +++ b/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs @@ -53,7 +53,6 @@ namespace TensorFlowNET.UnitTest.Estimators Assert.IsNull(config.tf_random_seed); Assert.AreEqual(100, config.save_summary_steps); Assert.AreEqual(600, config.save_checkpoints_secs); - Assert.IsNull(config.save_checkpoints_steps); Assert.AreEqual(5, config.keep_checkpoint_max); Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours); Assert.IsNull(config.service); diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs index 1e4d829c..f0a79ed6 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -289,6 +289,9 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void TF_GraphOperationByName_FromModel() { + if (!Directory.Exists(modelPath)) + return; + MultiThreadedUnitTestExecuter.Run(8, Core); //the core method