|
- using NumSharp.Core;
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Runtime.InteropServices;
- using System.Text;
-
namespace Tensorflow
{
    /// <summary>
    /// Base session type: creates a native TF_Session bound to a <c>Graph</c> and
    /// implements the feed/fetch pipeline behind <c>run</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the native session created in the constructor is never closed or
    /// deleted in this class, and <c>_opened</c>, <c>_closed</c>, <c>_current_version</c>
    /// and <c>_target</c> are never read here — confirm whether a subclass takes care of that.
    /// </remarks>
    public class BaseSession
    {
        protected Graph _graph;
        protected bool _opened;
        protected bool _closed;
        protected int _current_version;
        protected byte[] _target;
        protected IntPtr _session;

        /// <summary>
        /// Creates a native session for <paramref name="graph"/>.
        /// </summary>
        /// <param name="target">Execution target; stored UTF-8 encoded in <c>_target</c> (not passed to the C API here).</param>
        /// <param name="graph">Graph to run; when null, <c>ops.get_default_graph()</c> is used.</param>
        /// <exception cref="Exception">Whatever <c>Status.Check</c> throws when session creation fails.</exception>
        public BaseSession(string target = "", Graph graph = null)
        {
            _graph = graph ?? ops.get_default_graph();
            _target = Encoding.UTF8.GetBytes(target);

            var opts = c_api.TF_NewSessionOptions();
            var status = new Status();
            try
            {
                _session = c_api.TF_NewSession(_graph, opts, status);
            }
            finally
            {
                // The options are only needed while the session is being created.
                c_api.TF_DeleteSessionOptions(opts);
            }

            // Surface session-creation failures immediately instead of on the first run(),
            // matching how _call_tf_sessionrun checks its status.
            status.Check(true);
        }

        /// <summary>
        /// Runs the graph, evaluating <paramref name="fetches"/> with the given feeds.
        /// </summary>
        /// <param name="fetches">What to fetch; interpreted by <c>_FetchHandler</c>.</param>
        /// <param name="feed_dict">Values to substitute for graph tensors during this run.</param>
        /// <returns>The fetched value(s) as assembled by <c>_FetchHandler.build_results</c>.</returns>
        public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
        {
            return _run(fetches, feed_dict);
        }

        /// <summary>
        /// Resolves feed keys to graph tensors, normalizes feed values, executes the
        /// step and lets the fetch handler shape the results.
        /// </summary>
        /// <exception cref="NotImplementedException">A feed value has an unsupported type.</exception>
        private NDArray _run(object fetches, FeedItem[] feed_dict = null)
        {
            // Graph element -> value to feed, in a form _do_run knows how to turn into a Tensor.
            var feed_dict_tensor = new Dictionary<object, object>();

            // Validate and process feed_dict.
            if (feed_dict != null)
            {
                foreach (var feed in feed_dict)
                {
                    object subfeed = feed.Key;
                    object subfeed_val = feed.Value;
                    var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);

                    switch (subfeed_val)
                    {
                        case IntPtr pointer:
                            feed_dict_tensor[subfeed_t] = pointer;
                            break;
                        case NDArray nd:
                            feed_dict_tensor[subfeed_t] = nd;
                            break;
                        case float floatVal:
                            feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
                            break;
                        case double doubleVal:
                            // _do_run builds a Tensor directly from a double, so feed it as-is.
                            feed_dict_tensor[subfeed_t] = doubleVal;
                            break;
                        case int intVal:
                            feed_dict_tensor[subfeed_t] = (NDArray)intVal;
                            break;
                        case string str:
                            feed_dict_tensor[subfeed_t] = (NDArray)str;
                            break;
                        case byte[] bytes:
                            feed_dict_tensor[subfeed_t] = (NDArray)bytes;
                            break;
                        default:
                            throw new NotImplementedException("_run subfeed");
                    }
                }
            }

            // Create a fetch handler to take care of the structure of fetches.
            var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

            // Device movers are not implemented yet (_update_with_movers returns an empty
            // list); the call is kept to mirror the reference Python implementation.
            var _ = _update_with_movers();
            var final_fetches = fetch_handler.fetches();
            var final_targets = fetch_handler.targets();

            // NOTE: the step is always executed, even when there are no fetches or targets.
            var results = _do_run(final_targets.Select(x => (Operation)(object)x).ToList(), final_fetches, feed_dict_tensor);

            return fetch_handler.build_results(this, results);
        }

        /// <summary>
        /// Runs a step based on the given fetches and feeds.
        /// </summary>
        /// <param name="target_list">A list of operations to be run, but not fetched.</param>
        /// <param name="fetch_list">Tensors whose values are returned.</param>
        /// <param name="feed_dict">Graph tensor -> value (IntPtr, Tensor, NDArray, int, float or double).</param>
        /// <returns>
        /// A list of numpy ndarrays, corresponding to the elements of
        /// `fetch_list`. If the ith element of `fetch_list` contains the
        /// name of an operation, the first Tensor output of that operation
        /// will be returned for that element.
        /// </returns>
        /// <exception cref="NotImplementedException">A feed key is not a Tensor, or a value type is unsupported.</exception>
        private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
        {
            var feeds = feed_dict.Select(x =>
            {
                if (!(x.Key is Tensor tensor))
                {
                    throw new NotImplementedException("_do_run.feed_dict");
                }

                switch (x.Value)
                {
                    case IntPtr pointer:
                        return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), pointer);
                    case Tensor t1:
                        return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
                    case NDArray nd:
                        return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
                    case int intVal:
                        return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
                    case float floatVal:
                        return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
                    case double doubleVal:
                        return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
                    default:
                        throw new NotImplementedException("feed_dict data type");
                }
            }).ToArray();

            var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

            return _call_tf_sessionrun(feeds, fetches, target_list);
        }

        /// <summary>
        /// Invokes <c>TF_SessionRun</c> and converts every fetched output to an NDArray.
        /// </summary>
        /// <exception cref="Exception">Whatever <c>Status.Check</c> throws when the run fails.</exception>
        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
        {
            // Ensure any changes to the graph are reflected in the runtime.
            _extend_graph();

            var status = new Status();

            // Filled by TF_SessionRun with one output tensor handle per fetch.
            var output_values = new IntPtr[fetch_list.Length];

            c_api.TF_SessionRun(_session,
                run_options: null,
                inputs: feed_dict.Select(f => f.Key).ToArray(),
                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                ninputs: feed_dict.Length,
                outputs: fetch_list,
                output_values: output_values,
                noutputs: fetch_list.Length,
                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                ntargets: target_list.Count,
                run_metadata: IntPtr.Zero,
                status: status);

            status.Check(true);

            // NOTE(review): per the TensorFlow C API docs the caller owns the tensors returned in
            // output_values and must release them with TF_DeleteTensor; the feed tensors created in
            // _do_run are not released here either. Confirm whether Tensor frees its native handle,
            // otherwise every run() leaks native memory.
            var result = new NDArray[output_values.Length];
            for (int i = 0; i < output_values.Length; i++)
            {
                result[i] = fetchValue(output_values[i]);
            }

            return result;
        }

        /// <summary>
        /// Copies the contents of a native output tensor into a managed NDArray with the same shape.
        /// </summary>
        /// <param name="output">Native tensor handle produced by TF_SessionRun.</param>
        /// <exception cref="NotImplementedException">The tensor's dtype is not supported.</exception>
        private NDArray fetchValue(IntPtr output)
        {
            var tensor = new Tensor(output);
            var dims = tensor.shape.Select(x => (int)x).ToArray();
            var data = c_api.TF_TensorData(output);

            // Marshal.Copy rejects a null source pointer even for zero-length copies, so empty
            // tensors skip the copy and yield an empty array.
            switch (tensor.dtype)
            {
                case TF_DataType.TF_STRING:
                    // Only scalar string tensors are supported (the result is reshaped to a scalar).
                    return np.array(_decode_scalar_tf_string(tensor.Data())).reshape();
                case TF_DataType.TF_INT16:
                    var shorts = new short[tensor.size];
                    if (shorts.Length > 0)
                        Marshal.Copy(data, shorts, 0, shorts.Length);
                    return np.array(shorts).reshape(dims);
                case TF_DataType.TF_INT32:
                    var ints = new int[tensor.size];
                    if (ints.Length > 0)
                        Marshal.Copy(data, ints, 0, ints.Length);
                    return np.array(ints).reshape(dims);
                case TF_DataType.TF_FLOAT:
                    var floats = new float[tensor.size];
                    if (floats.Length > 0)
                        Marshal.Copy(data, floats, 0, floats.Length);
                    return np.array(floats).reshape(dims);
                case TF_DataType.TF_DOUBLE:
                    var doubles = new double[tensor.size];
                    if (doubles.Length > 0)
                        Marshal.Copy(data, doubles, 0, doubles.Length);
                    return np.array(doubles).reshape(dims);
                default:
                    throw new NotImplementedException("can't fetch output");
            }
        }

        /// <summary>
        /// Decodes the single element of a scalar TF_STRING tensor buffer as UTF-8.
        /// </summary>
        /// <remarks>
        /// TF_STRING buffers (TF 1.x C API) start with one native-endian uint64 offset per
        /// element, relative to the end of that offset table; each element is its byte length
        /// as a varint followed by the bytes. A scalar therefore has an 8-byte table.
        /// </remarks>
        /// <exception cref="InvalidOperationException">The varint length prefix is malformed.</exception>
        private static string _decode_scalar_tf_string(byte[] buffer)
        {
            const int offsetTableSize = sizeof(ulong); // one entry: assumes a scalar tensor
            int pos = offsetTableSize + checked((int)BitConverter.ToUInt64(buffer, 0));

            // Varint: 7 payload bits per byte, least-significant group first, high bit = "more".
            ulong length = 0;
            int shift = 0;
            byte current;
            do
            {
                if (shift > 28)
                {
                    throw new InvalidOperationException("Malformed TF_STRING length prefix.");
                }

                current = buffer[pos++];
                length |= (ulong)(current & 0x7F) << shift;
                shift += 7;
            } while ((current & 0x80) != 0);

            return Encoding.UTF8.GetString(buffer, pos, checked((int)length));
        }

        /// <summary>
        /// If a tensor handle that is fed to a device incompatible placeholder,
        /// we move the tensor to the right device, generate a new tensor handle,
        /// and update feed_dict to use the new handle.
        /// </summary>
        /// <remarks>Not implemented yet: always returns an empty list.</remarks>
        private List<object> _update_with_movers()
        {
            return new List<object> { };
        }

        /// <summary>
        /// Hook for propagating graph changes to the runtime before a run; currently a no-op.
        /// </summary>
        private void _extend_graph()
        {

        }
    }
}
|