diff --git a/src/TensorFlowNET.Core/Contrib/Train/HParams.cs b/src/TensorFlowNET.Core/Contrib/Train/HParams.cs new file mode 100644 index 00000000..bd85ad4c --- /dev/null +++ b/src/TensorFlowNET.Core/Contrib/Train/HParams.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Contrib.Train +{ + /// + /// Class to hold a set of hyperparameters as name-value pairs. + /// + public class HParams + { + public bool load_pretrained { get; set; } + + public HParams(bool load_pretrained) + { + this.load_pretrained = load_pretrained; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/SessionRunContext.cs b/src/TensorFlowNET.Core/Train/SessionRunContext.cs new file mode 100644 index 00000000..bb54bbe5 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/SessionRunContext.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + public class SessionRunContext + { + public SessionRunContext(Session session) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Train/_MonitoredSession.cs b/src/TensorFlowNET.Core/Train/_MonitoredSession.cs new file mode 100644 index 00000000..e89b1b89 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/_MonitoredSession.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + internal class _MonitoredSession + { + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs index d8543970..c68b8815 100644 --- a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs +++ b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs @@ -4,13 +4,18 @@ using System.Text; using static Tensorflow.Binding; using Tensorflow.Estimators; using System.Linq; +using Tensorflow.Contrib.Train; namespace Tensorflow.Models.ObjectDetection { public class ModelLib { public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, - int train_steps = 1) + HParams hparams = null, + string pipeline_config_path = null, + int train_steps = 0, + int sample_1_of_n_eval_examples = 0, + int sample_1_of_n_eval_on_train_examples = 1) { var estimator = tf.estimator.Estimator(config: run_config); diff --git a/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs new file mode 100644 index 00000000..c71a7d1e --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection.Utils +{ + public class ConfigUtil + { + public object get_configs_from_pipeline_file(string pipeline_config_path) + { + throw new NotImplementedException(""); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs index 9e25616c..03bbab7f 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection/Main.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Text; using Tensorflow; +using Tensorflow.Contrib.Train; using Tensorflow.Estimators; using Tensorflow.Models.ObjectDetection; using static Tensorflow.Binding; @@ -18,12 +19,23 @@ namespace TensorFlowNET.Examples.ImageProcessing.ObjectDetection ModelLib model_lib = new ModelLib(); + string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model"; + string pipeline_config_path = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model/faster_rcnn_resnet101_voc07.config"; + int num_train_steps = 1; + int sample_1_of_n_eval_examples = 1; + int sample_1_of_n_eval_on_train_examples = 5; + public bool Run() { - string model_dir = "D:/Projects/PythonLab/tf-models/research/object_detection/models/model"; - var config = tf.estimator.RunConfig(model_dir: model_dir); - var train_and_eval_dict = model_lib.create_estimator_and_inputs(run_config: config); + + var train_and_eval_dict = model_lib.create_estimator_and_inputs(run_config: config, + hparams: new HParams(true), + pipeline_config_path: pipeline_config_path, + train_steps: num_train_steps, + sample_1_of_n_eval_examples: sample_1_of_n_eval_examples, + sample_1_of_n_eval_on_train_examples: sample_1_of_n_eval_on_train_examples); + var estimator = train_and_eval_dict.estimator; var train_input_fn = train_and_eval_dict.train_input_fn; var eval_input_fns = train_and_eval_dict.eval_input_fns;