From 5f595eb5eb8fbfc33f5b37aabe18d2b58f5a7292 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Wed, 19 Dec 2018 17:44:37 -0600 Subject: [PATCH] feed_dict_tensor in BaseSession._run --- src/TensorFlowNET.Core/BaseSession.cs | 21 ++++++++++++++++++++- src/TensorFlowNET.Core/Tensor.cs | 2 ++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/BaseSession.cs b/src/TensorFlowNET.Core/BaseSession.cs index c2d4565e..4e7c983f 100644 --- a/src/TensorFlowNET.Core/BaseSession.cs +++ b/src/TensorFlowNET.Core/BaseSession.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -44,6 +45,24 @@ namespace Tensorflow private unsafe byte[] _run(Tensor fetches, Dictionary feed_dict = null) { + var feed_dict_tensor = new Dictionary(); + + if (feed_dict != null) + { + NDArray np_val = null; + foreach (var feed in feed_dict) + { + switch (feed.Value) + { + case float value: + np_val = np.asarray(value); + break; + } + + feed_dict_tensor[feed.Key] = np_val; + } + } + var status = new Status(); c_api.TF_SessionRun(_session, diff --git a/src/TensorFlowNET.Core/Tensor.cs b/src/TensorFlowNET.Core/Tensor.cs index 7c01cf71..e9737d5c 100644 --- a/src/TensorFlowNET.Core/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensor.cs @@ -13,6 +13,8 @@ namespace Tensorflow private DataType _dtype; public DataType dtype => _dtype; + public string name; + public Tensor(Operation op, int value_index, DataType dtype) { _op = op;