Browse Source

BaseSession: feeddict values are not auto-converted to NDArray any more (was a waste of time and memory)

tags/v0.10
Meinrad Recheis 6 years ago
parent
commit
9a58ebb60e
1 changed files with 89 additions and 85 deletions
  1. +89
    -85
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 89
- 85
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -1,17 +1,17 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
@@ -19,6 +19,7 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Text;

@@ -31,18 +32,18 @@ namespace Tensorflow
protected bool _closed;
protected int _current_version;
protected byte[] _target;
protected IntPtr _session;
public Status Status;
protected IntPtr _session;
public Status Status;
public Graph graph => _graph;

public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{
{
_graph = g is null ? ops.get_default_graph() : g;

_target = UTF8Encoding.UTF8.GetBytes(target);

SessionOptions newOpts = null;
if (opts == null)
if (opts == null)
newOpts = c_api.TF_NewSessionOptions();

Status = new Status();
@@ -50,7 +51,7 @@ namespace Tensorflow
_session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status);

// dispose newOpts
if (opts == null)
if (opts == null)
c_api.TF_DeleteSessionOptions(newOpts);

Status.Check(true);
@@ -63,7 +64,7 @@ namespace Tensorflow

public virtual NDArray run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] :
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}
@@ -86,57 +87,8 @@ namespace Tensorflow
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();

switch (subfeed_val)
{
case IntPtr val:
feed_dict_tensor[subfeed_t] = val;
break;
case NDArray val:
feed_dict_tensor[subfeed_t] = val;
break;
case float val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case double val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case short val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case long val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case long[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case string val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case byte[] val:
feed_dict_tensor[subfeed_t] = np.array(val);
break;
case char[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
default:
Console.WriteLine($"can't handle data type of subfeed_val");
throw new NotImplementedException("_run subfeed");
}
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed_val;
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
}
@@ -175,26 +127,78 @@ namespace Tensorflow
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = feed_dict.Select(x =>
var feeds = feed_dict.Select(x =>
{
if (x.Key is Tensor tensor)
{
switch (x.Value)
{
case IntPtr pointer:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
case Tensor t1:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
case NDArray nd:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd, tensor.dtype));
case int intVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
case float floatVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
case double doubleVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
#if _REGEN
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types%
case #1 v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case #1[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
%
#else
case sbyte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case sbyte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case byte[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case short[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ushort[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case int[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case uint[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case long[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case ulong[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case float[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case double[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Complex[] v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
#endif
case bool v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL));
case string v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case IntPtr v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
case Tensor v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
case NDArray v:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
default:
throw new NotImplementedException("feed_dict data type");
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}");
}
}
throw new NotImplementedException("_do_run.feed_dict");


Loading…
Cancel
Save