|
|
@@ -1,17 +1,17 @@ |
|
|
|
/*****************************************************************************
|
|
|
|
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
/***************************************************************************** |
|
|
|
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. |
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
you may not use this file except in compliance with the License. |
|
|
|
You may obtain a copy of the License at |
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
|
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
See the License for the specific language governing permissions and |
|
|
|
limitations under the License. |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using NumSharp; |
|
|
@@ -19,6 +19,7 @@ using System; |
|
|
|
using System.Collections; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Numerics; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
using System.Text; |
|
|
|
|
|
|
@@ -31,18 +32,18 @@ namespace Tensorflow |
|
|
|
protected bool _closed; |
|
|
|
protected int _current_version; |
|
|
|
protected byte[] _target; |
|
|
|
protected IntPtr _session;
|
|
|
|
public Status Status;
|
|
|
|
protected IntPtr _session; |
|
|
|
public Status Status; |
|
|
|
public Graph graph => _graph; |
|
|
|
|
|
|
|
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) |
|
|
|
{
|
|
|
|
{ |
|
|
|
_graph = g is null ? ops.get_default_graph() : g; |
|
|
|
|
|
|
|
_target = UTF8Encoding.UTF8.GetBytes(target); |
|
|
|
|
|
|
|
SessionOptions newOpts = null; |
|
|
|
if (opts == null)
|
|
|
|
if (opts == null) |
|
|
|
newOpts = c_api.TF_NewSessionOptions(); |
|
|
|
|
|
|
|
Status = new Status(); |
|
|
@@ -50,7 +51,7 @@ namespace Tensorflow |
|
|
|
_session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); |
|
|
|
|
|
|
|
// dispose newOpts |
|
|
|
if (opts == null)
|
|
|
|
if (opts == null) |
|
|
|
c_api.TF_DeleteSessionOptions(newOpts); |
|
|
|
|
|
|
|
Status.Check(true); |
|
|
@@ -63,7 +64,7 @@ namespace Tensorflow |
|
|
|
|
|
|
|
public virtual NDArray run(object fetches, Hashtable feed_dict = null) |
|
|
|
{ |
|
|
|
var feed_items = feed_dict == null ? new FeedItem[0] :
|
|
|
|
var feed_items = feed_dict == null ? new FeedItem[0] : |
|
|
|
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); |
|
|
|
return _run(fetches, feed_items); |
|
|
|
} |
|
|
@@ -86,57 +87,8 @@ namespace Tensorflow |
|
|
|
foreach (var (subfeed, subfeed_val) in feed_fn(feed)) |
|
|
|
{ |
|
|
|
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); |
|
|
|
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); |
|
|
|
|
|
|
|
switch (subfeed_val) |
|
|
|
{ |
|
|
|
case IntPtr val: |
|
|
|
feed_dict_tensor[subfeed_t] = val; |
|
|
|
break; |
|
|
|
case NDArray val: |
|
|
|
feed_dict_tensor[subfeed_t] = val; |
|
|
|
break; |
|
|
|
case float val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case double val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case short val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case int val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break;
|
|
|
|
case long val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break;
|
|
|
|
case long[] val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case int[] val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case string val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case byte[] val: |
|
|
|
feed_dict_tensor[subfeed_t] = np.array(val); |
|
|
|
break;
|
|
|
|
case char[] val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case bool val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
case bool[] val: |
|
|
|
feed_dict_tensor[subfeed_t] = (NDArray)val; |
|
|
|
break; |
|
|
|
default: |
|
|
|
Console.WriteLine($"can't handle data type of subfeed_val"); |
|
|
|
throw new NotImplementedException("_run subfeed"); |
|
|
|
}
|
|
|
|
|
|
|
|
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used |
|
|
|
feed_dict_tensor[subfeed_t] = subfeed_val; |
|
|
|
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); |
|
|
|
} |
|
|
|
} |
|
|
@@ -175,26 +127,78 @@ namespace Tensorflow |
|
|
|
/// </returns> |
|
|
|
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) |
|
|
|
{ |
|
|
|
var feeds = feed_dict.Select(x =>
|
|
|
|
var feeds = feed_dict.Select(x => |
|
|
|
{ |
|
|
|
if (x.Key is Tensor tensor) |
|
|
|
{ |
|
|
|
switch (x.Value) |
|
|
|
{ |
|
|
|
case IntPtr pointer: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer); |
|
|
|
case Tensor t1: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1); |
|
|
|
case NDArray nd: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd, tensor.dtype)); |
|
|
|
case int intVal: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal)); |
|
|
|
case float floatVal: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal)); |
|
|
|
case double doubleVal: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal)); |
|
|
|
#if _REGEN |
|
|
|
%types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] |
|
|
|
%foreach types% |
|
|
|
case #1 v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case #1[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
% |
|
|
|
#else |
|
|
|
case sbyte v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case sbyte[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case byte v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case byte[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case short v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case short[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case ushort v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case ushort[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case int v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case int[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case uint v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case uint[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case long v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case long[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case ulong v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case ulong[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case float v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case float[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case double v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case double[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case Complex v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case Complex[] v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
#endif |
|
|
|
case bool v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); |
|
|
|
case string v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case IntPtr v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); |
|
|
|
case Tensor v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
|
|
|
|
case NDArray v: |
|
|
|
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); |
|
|
|
default: |
|
|
|
throw new NotImplementedException("feed_dict data type"); |
|
|
|
throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "<null>")}"); |
|
|
|
} |
|
|
|
} |
|
|
|
throw new NotImplementedException("_do_run.feed_dict"); |
|
|
|