From 8bf4edb713891a89a3a3a92afac7466e2b7bed55 Mon Sep 17 00:00:00 2001 From: pkingwsd <7557104.com> Date: Thu, 16 May 2019 13:22:08 +0800 Subject: [PATCH] fix call sessionoptions issue --- .../Sessions/BaseSession.cs | 590 +++++++++--------- src/TensorFlowNET.Core/Sessions/Session.cs | 11 +- 2 files changed, 303 insertions(+), 298 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 22339226..b8e9f1e1 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -1,297 +1,301 @@ -using NumSharp; -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; - -namespace Tensorflow -{ - public class BaseSession - { - protected Graph _graph; - protected bool _opened; - protected bool _closed; - protected int _current_version; - protected byte[] _target; - protected IntPtr _session; - - public BaseSession(string target = "", Graph graph = null) - { - if (graph is null) - { - _graph = ops.get_default_graph(); - } - else - { - _graph = graph; - } - - _target = UTF8Encoding.UTF8.GetBytes(target); - var opts = c_api.TF_NewSessionOptions(); - var status = new Status(); - _session = c_api.TF_NewSession(_graph, opts, status); - - c_api.TF_DeleteSessionOptions(opts); - } - - public virtual NDArray run(object fetches, params FeedItem[] feed_dict) - { - return _run(fetches, feed_dict); - } - - public virtual NDArray run(object fetches, Hashtable feed_dict = null) - { +using NumSharp; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public class BaseSession + { + protected Graph _graph; + protected bool _opened; + protected bool _closed; + protected int _current_version; + protected byte[] _target; + protected IntPtr _session; + + + public BaseSession(string target , Graph graph ,SessionOptions opts) + { + if (graph is null) + { + _graph = ops.get_default_graph(); + } + else + { + _graph = graph; + } + + _target = UTF8Encoding.UTF8.GetBytes(target); + SessionOptions newOpts = null; + if (opts == null) + newOpts = c_api.TF_NewSessionOptions(); + var status = new Status(); + _session = c_api.TF_NewSession(_graph, opts?? newOpts, status); + + if (opts == null) + c_api.TF_DeleteSessionOptions(newOpts); + } + + public virtual NDArray run(object fetches, params FeedItem[] feed_dict) + { + return _run(fetches, feed_dict); + } + + public virtual NDArray run(object fetches, Hashtable feed_dict = null) + { var feed_items = feed_dict == null ? new FeedItem[0] : - feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); - return _run(fetches, feed_items); - } - - private NDArray _run(object fetches, FeedItem[] feed_dict = null) - { - var feed_dict_tensor = new Dictionary(); - var feed_map = new Dictionary(); - - Func> feed_fn = (item) => - { - return new (object, object)[] { (item.Key, item.Value) }; - }; - - // Validate and process feed_dict. - if (feed_dict != null) - { - foreach (var feed in feed_dict) - { - foreach (var (subfeed, subfeed_val) in feed_fn(feed)) - { - var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); - var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); - - switch (subfeed_val) - { - case IntPtr val: - feed_dict_tensor[subfeed_t] = val; - break; - case NDArray val: - feed_dict_tensor[subfeed_t] = val; - break; - case float val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case double val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case short val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case int val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case int[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case string val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case byte[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case bool val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - case bool[] val: - feed_dict_tensor[subfeed_t] = (NDArray)val; - break; - default: - Console.WriteLine($"can't handle data type of subfeed_val"); - throw new NotImplementedException("_run subfeed"); + feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + return _run(fetches, feed_items); + } + + private NDArray _run(object fetches, FeedItem[] feed_dict = null) + { + var feed_dict_tensor = new Dictionary(); + var feed_map = new Dictionary(); + + Func> feed_fn = (item) => + { + return new (object, object)[] { (item.Key, item.Value) }; + }; + + // Validate and process feed_dict. + if (feed_dict != null) + { + foreach (var feed in feed_dict) + { + foreach (var (subfeed, subfeed_val) in feed_fn(feed)) + { + var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); + var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); + + switch (subfeed_val) + { + case IntPtr val: + feed_dict_tensor[subfeed_t] = val; + break; + case NDArray val: + feed_dict_tensor[subfeed_t] = val; + break; + case float val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case double val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case short val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case int val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case int[] val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case string val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case byte[] val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case bool val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case bool[] val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + default: + Console.WriteLine($"can't handle data type of subfeed_val"); + throw new NotImplementedException("_run subfeed"); } - feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); - } - } - } - - // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); - - // Run request and get response. - // We need to keep the returned movers alive for the following _do_run(). - // These movers are no longer needed when _do_run() completes, and - // are deleted when `movers` goes out of scope when this _run() ends. - var _ = _update_with_movers(); - var final_fetches = fetch_handler.fetches(); - var final_targets = fetch_handler.targets(); - - // We only want to really perform the run if fetches or targets are provided, - // or if the call is a partial run that specifies feeds. - var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); - - return fetch_handler.build_results(this, results); - } - - /// - /// Runs a step based on the given fetches and feeds. - /// - /// - /// A list of operations to be run, but not fetched. - /// - /// - /// - /// A list of numpy ndarrays, corresponding to the elements of - /// `fetch_list`. If the ith element of `fetch_list` contains the - /// name of an operation, the first Tensor output of that operation - /// will be returned for that element. - /// - private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) - { + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); + } + } + } + + // Create a fetch handler to take care of the structure of fetches. + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); + + // Run request and get response. + // We need to keep the returned movers alive for the following _do_run(). + // These movers are no longer needed when _do_run() completes, and + // are deleted when `movers` goes out of scope when this _run() ends. + var _ = _update_with_movers(); + var final_fetches = fetch_handler.fetches(); + var final_targets = fetch_handler.targets(); + + // We only want to really perform the run if fetches or targets are provided, + // or if the call is a partial run that specifies feeds. + var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor); + + return fetch_handler.build_results(this, results); + } + + /// + /// Runs a step based on the given fetches and feeds. + /// + /// + /// A list of operations to be run, but not fetched. + /// + /// + /// + /// A list of numpy ndarrays, corresponding to the elements of + /// `fetch_list`. If the ith element of `fetch_list` contains the + /// name of an operation, the first Tensor output of that operation + /// will be returned for that element. + /// + private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) + { var feeds = feed_dict.Select(x => - { - if (x.Key is Tensor tensor) - { - switch (x.Value) - { - case IntPtr pointer: - return new KeyValuePair(tensor._as_tf_output(), pointer); - case Tensor t1: - return new KeyValuePair(tensor._as_tf_output(), t1); - case NDArray nd: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(nd)); - case int intVal: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(intVal)); - case float floatVal: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(floatVal)); - case double doubleVal: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(doubleVal)); - default: - throw new NotImplementedException("feed_dict data type"); - } - } - throw new NotImplementedException("_do_run.feed_dict"); - }).ToArray(); - var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); - var targets = target_list; - - return _call_tf_sessionrun(feeds, fetches, target_list); - } - - private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) - { - // Ensure any changes to the graph are reflected in the runtime. - _extend_graph(); - - var status = new Status(); - - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - - c_api.TF_SessionRun(_session, - run_options: null, - inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), - ninputs: feed_dict.Length, - outputs: fetch_list, - output_values: output_values, - noutputs: fetch_list.Length, - target_opers: target_list.Select(f => (IntPtr)f).ToArray(), - ntargets: target_list.Count, - run_metadata: IntPtr.Zero, - status: status); - - status.Check(true); - - var result = new NDArray[fetch_list.Length]; - - for (int i = 0; i < fetch_list.Length; i++) - result[i] = fetchValue(output_values[i]); - - for (int i = 0; i < feed_dict.Length; i++) - feed_dict[i].Value.Dispose(); - - return result; - } - - private unsafe NDArray fetchValue(IntPtr output) - { - var tensor = new Tensor(output); - NDArray nd = null; - Type type = tensor.dtype.as_numpy_datatype(); - var ndims = tensor.shape.Select(x => (int)x).ToArray(); - var offset = c_api.TF_TensorData(output); - - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - var bools = new bool[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(bools).reshape(ndims); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.Data(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = np.array(str).reshape(); - break; - case TF_DataType.TF_UINT8: - var _bytes = new byte[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(_bytes).reshape(ndims); - break; - case TF_DataType.TF_INT16: - var shorts = new short[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(shorts).reshape(ndims); - break; - case TF_DataType.TF_INT32: - var ints = new int[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(ints).reshape(ndims); - break; - case TF_DataType.TF_INT64: - var longs = new long[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(longs).reshape(ndims); - break; - case TF_DataType.TF_FLOAT: - var floats = new float[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(floats).reshape(ndims); - break; - case TF_DataType.TF_DOUBLE: - var doubles = new double[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(doubles).reshape(ndims); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - - tensor.Dispose(); - - return nd; - } - - /// - /// If a tensor handle that is fed to a device incompatible placeholder, - /// we move the tensor to the right device, generate a new tensor handle, - /// and update feed_dict to use the new handle. - /// - private List _update_with_movers() - { - return new List { }; - } - - private void _extend_graph() - { - - } - } -} + { + if (x.Key is Tensor tensor) + { + switch (x.Value) + { + case IntPtr pointer: + return new KeyValuePair(tensor._as_tf_output(), pointer); + case Tensor t1: + return new KeyValuePair(tensor._as_tf_output(), t1); + case NDArray nd: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(nd)); + case int intVal: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(intVal)); + case float floatVal: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(floatVal)); + case double doubleVal: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(doubleVal)); + default: + throw new NotImplementedException("feed_dict data type"); + } + } + throw new NotImplementedException("_do_run.feed_dict"); + }).ToArray(); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + var targets = target_list; + + return _call_tf_sessionrun(feeds, fetches, target_list); + } + + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) + { + // Ensure any changes to the graph are reflected in the runtime. + _extend_graph(); + + var status = new Status(); + + var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); + + c_api.TF_SessionRun(_session, + run_options: null, + inputs: feed_dict.Select(f => f.Key).ToArray(), + input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), + ninputs: feed_dict.Length, + outputs: fetch_list, + output_values: output_values, + noutputs: fetch_list.Length, + target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + ntargets: target_list.Count, + run_metadata: IntPtr.Zero, + status: status); + + status.Check(true); + + var result = new NDArray[fetch_list.Length]; + + for (int i = 0; i < fetch_list.Length; i++) + result[i] = fetchValue(output_values[i]); + + for (int i = 0; i < feed_dict.Length; i++) + feed_dict[i].Value.Dispose(); + + return result; + } + + private unsafe NDArray fetchValue(IntPtr output) + { + var tensor = new Tensor(output); + NDArray nd = null; + Type type = tensor.dtype.as_numpy_datatype(); + var ndims = tensor.shape.Select(x => (int)x).ToArray(); + var offset = c_api.TF_TensorData(output); + + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + var bools = new bool[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(bools).reshape(ndims); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.Data(); + // wired, don't know why we have to start from offset 9. + // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = np.array(str).reshape(); + break; + case TF_DataType.TF_UINT8: + var _bytes = new byte[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(_bytes).reshape(ndims); + break; + case TF_DataType.TF_INT16: + var shorts = new short[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(shorts).reshape(ndims); + break; + case TF_DataType.TF_INT32: + var ints = new int[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(ints).reshape(ndims); + break; + case TF_DataType.TF_INT64: + var longs = new long[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(longs).reshape(ndims); + break; + case TF_DataType.TF_FLOAT: + var floats = new float[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(floats).reshape(ndims); + break; + case TF_DataType.TF_DOUBLE: + var doubles = new double[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(doubles).reshape(ndims); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + + tensor.Dispose(); + + return nd; + } + + /// + /// If a tensor handle that is fed to a device incompatible placeholder, + /// we move the tensor to the right device, generate a new tensor handle, + /// and update feed_dict to use the new handle. + /// + private List _update_with_movers() + { + return new List { }; + } + + private void _extend_graph() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index a610f4e7..2e9e0d96 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -11,7 +11,7 @@ namespace Tensorflow public SessionOptions Options { get; } public Graph graph; - public Session(string target = "", Graph graph = null) + public Session(string target = "", Graph graph = null):base(target,graph,null) { if(graph == null) { @@ -20,20 +20,21 @@ namespace Tensorflow this.graph = graph; Options = new SessionOptions(); _handle = c_api.TF_NewSession(graph, Options, Status); + //why create session again. already created session in BaseSession. Status.Check(); } - public Session(IntPtr handle) + public Session(IntPtr handle):base("",null,null) { _handle = handle; } public Session(Graph g, SessionOptions opts = null, Status s = null) + :base(string.Empty,g,opts) { - if (s == null) - s = Status; + s = s ?? Status; graph = g; - Options = opts == null ? new SessionOptions() : opts; + Options = opts ?? new SessionOptions(); _handle = c_api.TF_NewSession(graph, Options, s); Status.Check(true); }