From 9a58ebb60e904a93a6f28390f3244a8f0856e0b2 Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Sat, 13 Jul 2019 16:31:43 +0200 Subject: [PATCH] BaseSession: feeddict values are not auto-converted to NDArray any more (was a waste of time and memory) --- .../Sessions/BaseSession.cs | 174 +++++++++--------- 1 file changed, 89 insertions(+), 85 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 66a1952e..34e0a06a 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -1,17 +1,17 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. ******************************************************************************/ using NumSharp; @@ -19,6 +19,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Numerics; using System.Runtime.InteropServices; using System.Text; @@ -31,18 +32,18 @@ namespace Tensorflow protected bool _closed; protected int _current_version; protected byte[] _target; - protected IntPtr _session; - public Status Status; + protected IntPtr _session; + public Status Status; public Graph graph => _graph; public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) - { + { _graph = g is null ? ops.get_default_graph() : g; _target = UTF8Encoding.UTF8.GetBytes(target); SessionOptions newOpts = null; - if (opts == null) + if (opts == null) newOpts = c_api.TF_NewSessionOptions(); Status = new Status(); @@ -50,7 +51,7 @@ namespace Tensorflow _session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); // dispose newOpts - if (opts == null) + if (opts == null) c_api.TF_DeleteSessionOptions(newOpts); Status.Check(true); @@ -63,7 +64,7 @@ namespace Tensorflow public virtual NDArray run(object fetches, Hashtable feed_dict = null) { - var feed_items = feed_dict == null ? new FeedItem[0] : + var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); return _run(fetches, feed_items); } @@ -86,57 +87,8 @@ namespace Tensorflow foreach (var (subfeed, subfeed_val) in feed_fn(feed)) { var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); - var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); - - switch (subfeed_val) - { - case IntPtr val: - feed_dict_tensor[subfeed_t] = val; - break; - case NDArray val: - feed_dict_tensor[subfeed_t] = val; - break; - case float val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case double val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case short val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case int val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case long val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case long[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case int[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case string val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case byte[] val: - feed_dict_tensor[subfeed_t] = np.array(val); - break; - case char[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case bool val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case bool[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - default: - Console.WriteLine($"can't handle data type of subfeed_val"); - throw new NotImplementedException("_run subfeed"); - } - + //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed_val; feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); } } @@ -175,26 +127,78 @@ namespace Tensorflow /// private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.Select(x => + var feeds = feed_dict.Select(x => { if (x.Key is Tensor tensor) { switch (x.Value) { - case IntPtr pointer: - return new KeyValuePair(tensor._as_tf_output(), pointer); - case Tensor t1: - return new KeyValuePair(tensor._as_tf_output(), t1); - case NDArray nd: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(nd, tensor.dtype)); - case int intVal: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(intVal)); - case float floatVal: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(floatVal)); - case double doubleVal: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(doubleVal)); +#if _REGEN + %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case #1[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + % +#else + case sbyte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case sbyte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); +#endif + case bool v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); + case string v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case IntPtr v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Tensor v: + return new KeyValuePair(tensor._as_tf_output(), v); + case NDArray v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); default: - throw new NotImplementedException("feed_dict data type"); + throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); } } throw new NotImplementedException("_do_run.feed_dict");