
Merge pull request #247 from pkingwsd/master

fix call sessionoptions issue
tags/v0.9
Haiping (GitHub), 6 years ago
commit 348f418d9c
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
2 changed files with 303 additions and 298 deletions
  1. src/TensorFlowNET.Core/Sessions/BaseSession.cs (+297 -293)
  2. src/TensorFlowNET.Core/Sessions/Session.cs (+6 -5)

src/TensorFlowNET.Core/Sessions/BaseSession.cs (+297 -293)

@@ -1,297 +1,301 @@
Removed: the old constructor (the surrounding using directives, class declaration, and fields are unchanged and are shown once in the added version below).

public BaseSession(string target = "", Graph graph = null)
{
if (graph is null)
{
_graph = ops.get_default_graph();
}
else
{
_graph = graph;
}

_target = UTF8Encoding.UTF8.GetBytes(target);
var opts = c_api.TF_NewSessionOptions();
var status = new Status();
_session = c_api.TF_NewSession(_graph, opts, status);

c_api.TF_DeleteSessionOptions(opts);
}

Added: the new version of the file (apart from the constructor, which now accepts a SessionOptions parameter, it matches the removed version).

using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
{
public class BaseSession
{
protected Graph _graph;
protected bool _opened;
protected bool _closed;
protected int _current_version;
protected byte[] _target;
protected IntPtr _session;
public BaseSession(string target, Graph graph, SessionOptions opts)
{
if (graph is null)
{
_graph = ops.get_default_graph();
}
else
{
_graph = graph;
}
_target = UTF8Encoding.UTF8.GetBytes(target);
// Only create a default options handle when the caller did not supply one,
// and delete that temporary handle once the session has been created.
SessionOptions newOpts = null;
if (opts == null)
newOpts = c_api.TF_NewSessionOptions();
var status = new Status();
_session = c_api.TF_NewSession(_graph, opts ?? newOpts, status);
if (opts == null)
c_api.TF_DeleteSessionOptions(newOpts);
}
public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
{
return _run(fetches, feed_dict);
}
public virtual NDArray run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}
private NDArray _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();
Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
{
return new (object, object)[] { (item.Key, item.Value) };
};
// Validate and process feed_dict.
if (feed_dict != null)
{
foreach (var feed in feed_dict)
{
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
switch (subfeed_val)
{
case IntPtr val:
feed_dict_tensor[subfeed_t] = val;
break;
case NDArray val:
feed_dict_tensor[subfeed_t] = val;
break;
case float val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case double val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case short val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case string val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case byte[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
default:
Console.WriteLine($"can't handle data type of subfeed_val");
throw new NotImplementedException("_run subfeed");
}
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
}
}
// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
// These movers are no longer needed when _do_run() completes, and
// are deleted when `movers` goes out of scope when this _run() ends.
var _ = _update_with_movers();
var final_fetches = fetch_handler.fetches();
var final_targets = fetch_handler.targets();
// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);
return fetch_handler.build_results(this, results);
}
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list"></param>
/// <param name="feed_dict"></param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = feed_dict.Select(x =>
{
if (x.Key is Tensor tensor)
{
switch (x.Value)
{
case IntPtr pointer:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
case Tensor t1:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
case NDArray nd:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
case int intVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
case float floatVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
case double doubleVal:
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
default:
throw new NotImplementedException("feed_dict data type");
}
}
throw new NotImplementedException("_do_run.feed_dict");
}).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
var targets = target_list;
return _call_tf_sessionrun(feeds, fetches, target_list);
}
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();
var status = new Status();
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();
c_api.TF_SessionRun(_session,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
ninputs: feed_dict.Length,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: status);
status.Check(true);
var result = new NDArray[fetch_list.Length];
for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(output_values[i]);
for (int i = 0; i < feed_dict.Length; i++)
feed_dict[i].Value.Dispose();
return result;
}
private unsafe NDArray fetchValue(IntPtr output)
{
var tensor = new Tensor(output);
NDArray nd = null;
Type type = tensor.dtype.as_numpy_datatype();
var ndims = tensor.shape.Select(x => (int)x).ToArray();
var offset = c_api.TF_TensorData(output);
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
var bools = new bool[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
// Weird: the string data has to be read starting at offset 9;
// the length is stored just before it, in byte 8.
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
nd = np.array(shorts).reshape(ndims);
break;
case TF_DataType.TF_INT32:
var ints = new int[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
nd = np.array(ints).reshape(ndims);
break;
case TF_DataType.TF_INT64:
var longs = new long[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
nd = np.array(longs).reshape(ndims);
break;
case TF_DataType.TF_FLOAT:
var floats = new float[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
nd = np.array(floats).reshape(ndims);
break;
case TF_DataType.TF_DOUBLE:
var doubles = new double[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
nd = np.array(doubles).reshape(ndims);
break;
default:
throw new NotImplementedException("can't fetch output");
}
tensor.Dispose();
return nd;
}
/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle.
/// </summary>
private List<object> _update_with_movers()
{
return new List<object> { };
}
private void _extend_graph()
{
}
}
}
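
Below is a hedged usage sketch (not part of this commit) of the two run() overloads defined above; the tensors x and y are hypothetical placeholders standing in for nodes built elsewhere in the graph:

using System.Collections;
using NumSharp;
using Tensorflow;

public static class RunOverloadsExample
{
    public static NDArray Evaluate(Session sess, Tensor x, Tensor y)
    {
        // params FeedItem[] overload: each feed value is converted by the
        // switch in _run() (a float becomes an NDArray, for example).
        var viaFeedItems = sess.run(y, new FeedItem(x, 3.0f));

        // Hashtable overload: keys and values are wrapped into FeedItem
        // objects before being handed to _run().
        var table = new Hashtable { [x] = 3.0f };
        var viaHashtable = sess.run(y, table);

        return viaHashtable;
    }
}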

src/TensorFlowNET.Core/Sessions/Session.cs (+6 -5)

@@ -11,7 +11,7 @@ namespace Tensorflow
public SessionOptions Options { get; }
public Graph graph;

- public Session(string target = "", Graph graph = null)
+ public Session(string target = "", Graph graph = null) : base(target, graph, null)
{
if(graph == null)
{
@@ -20,20 +20,21 @@ namespace Tensorflow
this.graph = graph;
Options = new SessionOptions();
_handle = c_api.TF_NewSession(graph, Options, Status);
+ //why create session again. already created session in BaseSession.
Status.Check();
}

- public Session(IntPtr handle)
+ public Session(IntPtr handle) : base("", null, null)
{
_handle = handle;
}

public Session(Graph g, SessionOptions opts = null, Status s = null)
+ : base(string.Empty, g, opts)
{
- if (s == null)
- s = Status;
+ s = s ?? Status;
graph = g;
- Options = opts == null ? new SessionOptions() : opts;
+ Options = opts ?? new SessionOptions();
_handle = c_api.TF_NewSession(graph, Options, s);
Status.Check(true);
}
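
A minimal hedged sketch (not part of this commit) of what the fix enables: a caller-supplied SessionOptions now reaches the underlying TF_NewSession call through the base constructor, instead of only ever using a default options object. It uses only the Session, SessionOptions, and ops members visible in this diff:

using Tensorflow;

public static class SessionOptionsExample
{
    public static void Build()
    {
        var graph = ops.get_default_graph();

        // Options supplied here flow through
        // Session(Graph, SessionOptions) -> base(string.Empty, g, opts) -> TF_NewSession.
        var opts = new SessionOptions();
        var sessWithOpts = new Session(graph, opts);

        // Without options, BaseSession creates a temporary TF_NewSessionOptions
        // handle and deletes it right after the session is created.
        var sessDefault = new Session(graph);
    }
}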

