using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Base session type: owns a native TF_Session created for a <see cref="Graph"/>
    /// and executes fetches/feeds through the TensorFlow C API.
    /// </summary>
    public class BaseSession
    {
        protected Graph _graph;
        protected bool _opened;
        protected bool _closed;
        protected int _current_version;
        protected byte[] _target;
        protected IntPtr _session;

        /// <summary>
        /// Creates a native session for <paramref name="graph"/>, or for the default graph when it is null.
        /// </summary>
        /// <param name="target">Execution target; stored UTF-8 encoded in <c>_target</c> (not otherwise used here).</param>
        /// <param name="graph">Graph to execute; defaults to <c>ops.get_default_graph()</c>.</param>
        /// <remarks>Any error reported by <c>TF_NewSession</c> is surfaced through <c>Status.Check(true)</c>.</remarks>
        public BaseSession(string target = "", Graph graph = null)
        {
            _graph = graph is null ? ops.get_default_graph() : graph;

            _target = Encoding.UTF8.GetBytes(target);

            var opts = c_api.TF_NewSessionOptions();
            var status = new Status();
            _session = c_api.TF_NewSession(_graph, opts, status);

            // Release the options before surfacing any error so they are not leaked on failure.
            c_api.TF_DeleteSessionOptions(opts);

            // Without this check a failed TF_NewSession would go unnoticed until the first run.
            status.Check(true);
        }

        /// <summary>
        /// Evaluates <paramref name="fetches"/> with the given feeds.
        /// </summary>
        /// <param name="fetches">The graph element(s) to fetch; interpreted by <c>_FetchHandler</c>.</param>
        /// <param name="feed_dict">Values to feed, keyed by graph element.</param>
        /// <returns>The result built by <c>_FetchHandler.build_results</c>.</returns>
        public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
        {
            return _run(fetches, feed_dict);
        }

        /// <summary>
        /// Validates the feeds, resolves fetches/targets and performs the run.
        /// </summary>
        /// <exception cref="NotImplementedException">A feed value has an unsupported type.</exception>
        private NDArray _run(object fetches, FeedItem[] feed_dict = null)
        {
            var feed_dict_tensor = new Dictionary<object, object>();
            // Kept for parity with the Python implementation; populated but not read in this method.
            var feed_map = new Dictionary<object, object>();

            Func<FeedItem, IEnumerable<(object, object)>> feed_fn = item =>
                new (object, object)[] { (item.Key, item.Value) };

            // Validate and process feed_dict.
            if (feed_dict != null)
            {
                foreach (var feed in feed_dict)
                {
                    foreach (var (subfeed, subfeed_val) in feed_fn(feed))
                    {
                        var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
                        var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();

                        switch (subfeed_val)
                        {
                            case IntPtr pointer:
                                feed_dict_tensor[subfeed_t] = pointer;
                                break;
                            case NDArray nd:
                                feed_dict_tensor[subfeed_t] = nd;
                                break;
                            case float floatVal:
                                feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
                                break;
                            case int intVal:
                                feed_dict_tensor[subfeed_t] = (NDArray)intVal;
                                break;
                            case string str:
                                feed_dict_tensor[subfeed_t] = (NDArray)str;
                                break;
                            case byte[] bytes:
                                feed_dict_tensor[subfeed_t] = (NDArray)bytes;
                                break;
                            default:
                                // NOTE(review): _do_run also understands Tensor and double values,
                                // but they are rejected here — confirm whether that is intended.
                                throw new NotImplementedException("_run subfeed");
                        }

                        feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
                    }
                }
            }

            // Create a fetch handler to take care of the structure of fetches.
            var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

            // Run request and get response.
            // We need to keep the returned movers alive for the following _do_run().
            // These movers are no longer needed when _do_run() completes, and
            // are deleted when `movers` goes out of scope when this _run() ends.
            var movers = _update_with_movers();
            var final_fetches = fetch_handler.fetches();
            var final_targets = fetch_handler.targets();

            // NOTE(review): the Python version only performs the run when fetches or targets are
            // provided (or for a partial run with feeds); this port always runs.
            var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);

            return fetch_handler.build_results(this, results);
        }

        /// <summary>
        /// Runs a step based on the given fetches and feeds.
        /// </summary>
        /// <param name="target_list">A list of operations to be run, but not fetched.</param>
        /// <param name="fetch_list">Tensors whose values are returned.</param>
        /// <param name="feed_dict">
        /// Fed tensor mapped to its value: an <see cref="IntPtr"/>, <c>Tensor</c>, <c>NDArray</c>, int, float or double.
        /// </param>
        /// <returns>One array per element of <paramref name="fetch_list"/>, in the same order.</returns>
        /// <exception cref="NotImplementedException">A key is not a Tensor, or a value has an unsupported type.</exception>
        private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
        {
            var feeds = feed_dict.Select(x =>
            {
                if (x.Key is Tensor tensor)
                {
                    switch (x.Value)
                    {
                        case IntPtr pointer:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
                        case Tensor t1:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
                        case NDArray nd:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
                        case int intVal:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
                        case float floatVal:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
                        case double doubleVal:
                            return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
                        default:
                            throw new NotImplementedException("feed_dict data type");
                    }
                }

                throw new NotImplementedException("_do_run.feed_dict");
            }).ToArray();

            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

            return _call_tf_sessionrun(feeds, fetches, target_list);
        }

        /// <summary>
        /// Calls <c>TF_SessionRun</c> and converts every fetched output into an <c>NDArray</c>.
        /// </summary>
        /// <remarks>Errors reported by the C API are surfaced through <c>Status.Check(true)</c>.</remarks>
        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
        {
            // Ensure any changes to the graph are reflected in the runtime.
            _extend_graph();

            var status = new Status();

            // Filled in by TF_SessionRun with one native tensor handle per fetch.
            var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

            c_api.TF_SessionRun(_session,
                run_options: null,
                inputs: feed_dict.Select(f => f.Key).ToArray(),
                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                ninputs: feed_dict.Length,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: fetch_list.Length,
                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                ntargets: target_list.Count,
                run_metadata: IntPtr.Zero,
                status: status);

            status.Check(true);

            var result = new NDArray[fetch_list.Length];

            for (int i = 0; i < fetch_list.Length; i++)
            {
                result[i] = fetchValue(output_values[i]);
            }

            return result;
        }

        /// <summary>
        /// Copies a native output tensor into an <c>NDArray</c> with the same shape.
        /// </summary>
        /// <param name="output">Native TF_Tensor handle returned by <c>TF_SessionRun</c>.</param>
        /// <exception cref="NotImplementedException">The tensor's dtype is not supported.</exception>
        private NDArray fetchValue(IntPtr output)
        {
            var tensor = new Tensor(output);
            NDArray nd = null;
            Type type = tensor.dtype.as_numpy_datatype();
            var dims = tensor.shape.Select(x => (int)x).ToArray();
            var data = c_api.TF_TensorData(output);

            switch (tensor.dtype)
            {
                case TF_DataType.TF_STRING:
                    // Only a single (scalar) string is decoded here.
                    nd = np.array(DecodeScalarString(tensor.Data())).reshape();
                    break;
                case TF_DataType.TF_INT16:
                    var shorts = new short[checked((int)tensor.size)];
                    Marshal.Copy(data, shorts, 0, shorts.Length);
                    nd = np.array(shorts).reshape(dims);
                    break;
                case TF_DataType.TF_INT32:
                    var ints = new int[checked((int)tensor.size)];
                    Marshal.Copy(data, ints, 0, ints.Length);
                    nd = np.array(ints).reshape(dims);
                    break;
                case TF_DataType.TF_FLOAT:
                    var floats = new float[checked((int)tensor.size)];
                    Marshal.Copy(data, floats, 0, floats.Length);
                    nd = np.array(floats).reshape(dims);
                    break;
                case TF_DataType.TF_DOUBLE:
                    var doubles = new double[checked((int)tensor.size)];
                    Marshal.Copy(data, doubles, 0, doubles.Length);
                    nd = np.array(doubles).reshape(dims);
                    break;
                default:
                    throw new NotImplementedException("can't fetch output");
            }

            return nd;
        }

        /// <summary>
        /// Decodes the string held by a scalar TF_STRING tensor buffer.
        /// </summary>
        /// <remarks>
        /// The buffer begins with one 8-byte offset-table entry, followed by the string length
        /// encoded as a base-128 varint and then the string bytes. The length is a varint, so
        /// reading only <c>bytes[8]</c> would be correct only for strings shorter than 128 bytes.
        /// </remarks>
        /// <exception cref="InvalidOperationException">The length prefix is truncated or longer than 64 bits.</exception>
        private static string DecodeScalarString(byte[] bytes)
        {
            int pos = sizeof(ulong); // skip the single offset-table entry
            ulong length = 0;

            for (int shift = 0; ; shift += 7)
            {
                if (shift > 63 || pos >= bytes.Length)
                {
                    throw new InvalidOperationException("Malformed TF_STRING tensor: invalid varint length prefix.");
                }

                byte b = bytes[pos++];
                length |= (ulong)(b & 0x7F) << shift;

                if ((b & 0x80) == 0)
                {
                    break;
                }
            }

            return Encoding.UTF8.GetString(bytes, pos, checked((int)length));
        }

        /// <summary>
        /// If a tensor handle that is fed to a device incompatible placeholder,
        /// we move the tensor to the right device, generate a new tensor handle,
        /// and update feed_dict to use the new handle.
        /// </summary>
        /// <remarks>Not implemented yet: always returns an empty list.</remarks>
        private List<object> _update_with_movers()
        {
            return new List<object> { };
        }

        /// <summary>
        /// Hook for propagating graph changes to the runtime before a run; currently a no-op.
        /// </summary>
        private void _extend_graph()
        {
        }
    }
}