|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405 |
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Google.Protobuf;
using NumSharp.Core;
using Tensorflow;
using node_def_pb2 = Tensorflow;
-
namespace Tensorflow
{
    public partial class ops
    {
        /// <summary>
        /// Wrapper for `Graph.add_to_collection()` using the graph returned by `tf.get_default_graph()`.
        /// </summary>
        /// <param name="name">The key for the collection.</param>
        /// <param name="value">The value to add to the collection.</param>
        public static void add_to_collection<T>(string name, T value)
        {
            var graph = tf.get_default_graph();
            graph.add_to_collection(name, value);
        }

        /// <summary>
        /// Wrapper for `Graph.add_to_collections()` using the graph returned by `tf.get_default_graph()`.
        /// </summary>
        /// <param name="names">The keys of the collections to add to.</param>
        /// <param name="value">The value to add to each collection.</param>
        public static void add_to_collections<T>(List<string> names, T value)
        {
            var graph = tf.get_default_graph();
            graph.add_to_collections(names, value);
        }

        /// <summary>
        /// Wrapper for `Graph.get_collection()` using the default graph.
        /// </summary>
        /// <param name="key">
        /// The key for the collection. For example, the `GraphKeys` class
        /// contains many standard names for collections.
        /// </param>
        /// <param name="scope">Scope filter forwarded unchanged to `Graph.get_collection()`.</param>
        /// <returns>
        /// The list of values in the collection with the given `name`, or
        /// an empty list if no value has been added to that collection. The
        /// list contains the values in the order under which they were
        /// collected.
        /// </returns>
        public static object get_collection(string key, string scope = "")
        {
            return get_default_graph().get_collection(key, scope);
        }

        /// <summary>
        /// Returns the default graph, as provided by `tf.Graph()`.
        /// </summary>
        public static Graph get_default_graph()
        {
            return tf.Graph();
        }

        /// <summary>
        /// Returns the graph in which an op with the given inputs should be created.
        /// </summary>
        /// <param name="op_input_list">
        /// The op's inputs. Not inspected yet: the Python implementation's check that all
        /// inputs live in the same graph has not been ported.
        /// </param>
        /// <param name="graph">Explicit graph to use; when null the default graph is returned.</param>
        /// <returns><paramref name="graph"/> if supplied, otherwise the default graph.</returns>
        public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
        {
            // TODO: port per-input graph validation (all inputs must share one graph).
            // Mirrors Python's `return graph or get_default_graph()`; the previous version
            // ignored an explicitly passed graph.
            return graph ?? get_default_graph();
        }

        /// <summary>
        /// Converts the given `value` to a `Tensor`.
        /// </summary>
        /// <param name="value">A `Tensor` (returned as-is) or any value `tensor_util.convert_to_numpy_ndarray` accepts.</param>
        /// <param name="dtype">
        /// NOTE(review): currently ignored; the resulting dtype is whatever `constant_op.constant`
        /// derives from the converted array.
        /// </param>
        /// <param name="name">Name for the created constant op.</param>
        /// <returns>The input tensor, or a new constant tensor built from <paramref name="value"/>.</returns>
        public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
        {
            if (value is Tensor tensor)
                return tensor;

            var nd = tensor_util.convert_to_numpy_ndarray(value);
            return constant_op.constant(nd, name);
        }

        /// <summary>
        /// Converts <paramref name="value"/> to a `Tensor` (composite tensors are not yet supported).
        /// </summary>
        public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
        {
            return internal_convert_to_tensor_or_composite(value: value, dtype: dtype, name: name, as_ref: false);
        }

        /// <summary>
        /// Forwards to <c>internal_convert_to_tensor&lt;Tensor&gt;</c>, translating <paramref name="dtype"/>
        /// to its `DataType` enum.
        /// </summary>
        public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
        {
            return internal_convert_to_tensor<Tensor>(value, dtype: dtype.as_datatype_enum(), name: name, as_ref: as_ref);
        }

        /// <summary>
        /// Wrapper for `Graph.control_dependencies()` using the default graph.
        /// </summary>
        /// <param name="control_inputs">Operations that ops created in the scope must depend on (may be null, as `init_scope` passes).</param>
        public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
        {
            return get_default_graph().control_dependencies(control_inputs);
        }

        /// <summary>
        /// Creates a TF_Operation.
        /// </summary>
        /// <param name="graph">a `Graph`.</param>
        /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param>
        /// <param name="inputs">
        /// A list of `Tensor`s (corresponding to scalar inputs) and lists of
        /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
        /// "list(int64)"). The length of the list should be equal to the number of
        /// inputs specified by this operation's op def. May be null (no inputs).
        /// </param>
        /// <param name="control_inputs">
        /// A list of `Operation`s to set as control dependencies. May be null (no control inputs).
        /// </param>
        /// <returns>A wrapped TF_Operation*.</returns>
        /// <exception cref="NotImplementedException">An element of <paramref name="inputs"/> is neither a `Tensor` nor a `Tensor[]`.</exception>
        public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
        {
            var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

            // Add inputs: a Tensor[] element is a sequence input, a single Tensor a scalar input.
            if (inputs != null)
            {
                foreach (var op_input in inputs)
                {
                    if (op_input is Tensor[] op_inputs)
                        c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
                    else if (op_input is Tensor op_input1)
                        c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
                    else
                        throw new NotImplementedException("_create_c_op");
                }
            }

            var status = new Status();

            // Add control inputs. Null means "none" (callers such as init_scope pass null).
            if (control_inputs != null)
            {
                foreach (var control_input in control_inputs)
                    c_api.TF_AddControlInput(op_desc, control_input);
            }

            // Add attrs. Each AttrValue is serialized into a temporary unmanaged buffer for the
            // C API; TF_SetAttrValueProto parses the bytes during the call, so the buffer is
            // freed immediately afterwards instead of being leaked.
            foreach (var attr in node_def.Attr)
            {
                var bytes = attr.Value.ToByteArray();
                var proto = Marshal.AllocHGlobal(bytes.Length);
                try
                {
                    Marshal.Copy(bytes, 0, proto, bytes.Length);
                    c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                }
                finally
                {
                    Marshal.FreeHGlobal(proto);
                }

                status.Check(true);
            }

            var c_op = c_api.TF_FinishOperation(op_desc, status);

            status.Check(true);

            return c_op;
        }

        /// <summary>
        /// Looks up the `OpDef` for <paramref name="type"/> via `graph.GetOpDef`.
        /// </summary>
        public static OpDef _get_op_def(Graph graph, string type)
        {
            return graph.GetOpDef(type);
        }

        /// <summary>
        /// Builds a `NodeDef` protobuf for an op.
        /// </summary>
        /// <param name="op_type">The op type (e.g. "Add").</param>
        /// <param name="name">The node name.</param>
        /// <param name="device">Device string; left unset when null or empty.</param>
        /// <param name="attrs">Attributes to copy into the node; null means no attributes.</param>
        /// <returns>The populated `NodeDef`.</returns>
        public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
        {
            var node_def = new node_def_pb2.NodeDef();
            node_def.Op = op_type;
            node_def.Name = name;

            // Previously `device` was silently dropped. Google.Protobuf string setters throw on
            // null, so only assign a non-empty value.
            if (!string.IsNullOrEmpty(device))
                node_def.Device = device;

            // `attrs` defaults to null; enumerating it unconditionally threw.
            if (attrs != null)
            {
                foreach (var attr in attrs)
                    node_def.Attr.Add(attr.Key, attr.Value);
            }

            return node_def;
        }

        /// <summary>
        /// Strips a single trailing '/' from a name scope to obtain the plain name.
        /// </summary>
        /// <param name="name">A scope name such as "outer/inner/". Null or empty is returned unchanged.</param>
        /// <returns><paramref name="name"/> without its trailing '/', if it had one.</returns>
        public static string _name_from_scope_name(string name)
        {
            // Ordinal: culture-sensitive EndsWith can ignore trailing ignorable characters.
            if (!string.IsNullOrEmpty(name) && name.EndsWith("/", StringComparison.Ordinal))
                return name.Substring(0, name.Length - 1);

            return name;
        }

        /// <summary>
        /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
        /// </summary>
        /// <remarks>
        /// Partial port: the name scope is computed but not yet re-entered, and the body run
        /// under `control_dependencies(null)` does not do anything with the graph it fetches.
        /// </remarks>
        public static void init_scope()
        {
            // Retrieve the active name scope: entering an `init_scope` preserves
            // the name scope of the current context.
            var default_graph = get_default_graph();
            var scope = default_graph.get_name_scope();
            // Names that end with trailing slashes are treated by `name_scope` as absolute.
            if (!string.IsNullOrEmpty(scope) && !scope.EndsWith("/", StringComparison.Ordinal))
                scope += "/";
            // TODO: port device-stack handling and re-enter `scope` (inner_device_stack / outer_context).

            Python.with(ops.control_dependencies(null), delegate
            {
                var outer_graph = get_default_graph();
                // outer_device_stack = None
            });
        }

        private static int uid_number = 0;

        /// <summary>
        /// A unique (within this program execution) integer: 0 on the first call, then 1, 2, ...
        /// Thread safe: the counter is advanced atomically.
        /// </summary>
        /// <returns>The next unused id.</returns>
        public static int uid()
        {
            // Interlocked.Increment returns the incremented value; subtract 1 to keep the
            // post-increment semantics (first call returns 0) of the original `uid_number++`.
            return Interlocked.Increment(ref uid_number) - 1;
        }

        /// <summary>
        /// Colocates subsequently created ops with <paramref name="op"/> on the default graph.
        /// </summary>
        public static void colocate_with(Operation op, bool ignore_existing = false)
        {
            _colocate_with_for_gradient(op, null, ignore_existing);
        }

        /// <summary>
        /// Colocates subsequently created ops with the op that produces <paramref name="tensor"/>.
        /// </summary>
        public static void colocate_with(Tensor tensor, bool ignore_existing = false)
        {
            _colocate_with_for_gradient(tensor.op, null, ignore_existing);
        }

        /// <summary>
        /// Forwards to `Graph._colocate_with_for_gradient` on the default graph.
        /// </summary>
        public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false)
        {
            var default_graph = get_default_graph();
            default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
        }

        /// <summary>
        /// Uses the default session to evaluate one or more tensors.
        /// </summary>
        /// <param name="tensor">The Tensor to evaluate.</param>
        /// <param name="feed_dict">
        /// Feed items mapping Tensor objects (or tensor names) to lists,
        /// numpy ndarrays, TensorProtos, or strings.
        /// </param>
        /// <param name="graph">The graph in which the tensor is defined.</param>
        /// <param name="session">A different session to use to evaluate <paramref name="tensor"/>; null means the default session.</param>
        /// <returns>The value of <paramref name="tensor"/> as an `NDArray`.</returns>
        /// <exception cref="ValueError">
        /// No default session is available, or the session's graph differs from <paramref name="graph"/>.
        /// </exception>
        public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null)
        {
            if (session == null)
            {
                session = get_default_session();

                if (session == null)
                    throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
                                         "session is registered. Use `with " +
                                         "sess.as_default()` or pass an explicit session to " +
                                         "`eval(session=sess)`");

                if (session.graph != graph)
                    throw new ValueError("Cannot use the default session to evaluate tensor: " +
                                         "the tensor's graph is different from the session's " +
                                         "graph. Pass an explicit session to " +
                                         "`eval(session=sess)`.");
            }
            else if (session.graph != graph)
            {
                // The caller supplied this session explicitly, so the message must not blame
                // (or suggest replacing) the default session.
                throw new ValueError("Cannot use the given session to evaluate tensor: " +
                                     "the tensor's graph is different from the session's " +
                                     "graph.");
            }

            return session.run(tensor, feed_dict);
        }

        /// <summary>
        /// Returns the default session for the current thread.
        /// </summary>
        /// <returns>The `Session` returned by `tf.Session()`.</returns>
        public static Session get_default_session()
        {
            return tf.Session();
        }

        /// <summary>
        /// Returns a gradient function for ops, or null when <paramref name="op"/> has no inputs list.
        /// </summary>
        /// <param name="op">Only used for the null-inputs check.</param>
        /// <returns>
        /// A delegate that dispatches on the type of the operation passed to it (not on
        /// <paramref name="op"/>). Supported types: Add, Identity, Mul, Sum, Sub, Pow, RealDiv;
        /// any other type throws <see cref="NotImplementedException"/>.
        /// </returns>
        public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op)
        {
            if (op.inputs == null) return null;

            return (oper, out_grads) =>
            {
                // Debug trace of every gradient lookup.
                Console.WriteLine($"get_gradient_function: {oper.type} '{oper.Name}'");

                switch (oper.type)
                {
                    case "Add":
                        var add = math_grad._AddGrad(oper, out_grads);
                        return new Tensor[] { add.Item1, add.Item2 };
                    case "Identity":
                        var id = math_grad._IdGrad(oper, out_grads);
                        return new Tensor[] { id };
                    case "Mul":
                        var mul = math_grad._MulGrad(oper, out_grads);
                        return new Tensor[] { mul.Item1, mul.Item2 };
                    case "Sum":
                        var sum = math_grad._SumGrad(oper, out_grads);
                        return new Tensor[] { sum.Item1, sum.Item2 };
                    case "Sub":
                        var sub = math_grad._SubGrad(oper, out_grads);
                        return new Tensor[] { sub.Item1, sub.Item2 };
                    case "Pow":
                        var pow = math_grad._PowGrad(oper, out_grads);
                        return new Tensor[] { pow.Item1, pow.Item2 };
                    case "RealDiv":
                        var realdiv = math_grad._RealDivGrad(oper, out_grads);
                        return new Tensor[] { realdiv.Item1, realdiv.Item2 };
                    default:
                        throw new NotImplementedException($"get_gradient_function {oper.type}");
                }
            };
        }

        /// <summary>
        /// Converts each of <paramref name="values"/> via `internal_convert_to_tensor_or_indexed_slices`.
        /// </summary>
        public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
        {
            return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
        }

        /// <summary>
        /// Converts <paramref name="value"/> via `internal_convert_to_tensor_or_indexed_slices` with <c>as_ref: false</c>.
        /// </summary>
        public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
        {
            return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
        }

        /// <summary>
        /// Currently a pass-through: returns <paramref name="value"/> unchanged.
        /// <paramref name="dtype"/>, <paramref name="name"/> and <paramref name="as_ref"/> are ignored.
        /// </summary>
        public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
        {
            return value;
        }

        /// <summary>
        /// Converts each element of <paramref name="values"/>, keeping null entries as null.
        /// </summary>
        /// <param name="name">When non-empty, element i is converted with the name "{name}_{i}".</param>
        /// <returns>The converted tensors, in input order.</returns>
        public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
        {
            var ret = new List<Tensor>();

            foreach (var (i, value) in Python.enumerate(values))
            {
                if (value == null)
                {
                    ret.Add(value);
                }
                else
                {
                    var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
                    ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref));
                }
            }

            return ret.ToArray();
        }

        /// <summary>
        /// Converts each element of <paramref name="values"/> with `internal_convert_to_tensor`.
        /// </summary>
        /// <param name="name">When non-empty, element i is converted with the name "{name}_{i}".</param>
        /// <returns>The converted tensors, in input order.</returns>
        public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
            string name = "", DataType preferred_dtype = DataType.DtInvalid,
            bool as_ref = false)
        {
            var ret = new List<Tensor>();

            foreach ((int i, T value) in Python.enumerate(values))
            {
                string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
                ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
            }

            return ret.ToArray();
        }

        /// <summary>
        /// Converts <paramref name="value"/> to a `Tensor`.
        /// </summary>
        /// <remarks>
        /// Dispatch is first on the static type <typeparamref name="T"/> (Tensor, string, string[],
        /// int, double, RefVariable). When <typeparamref name="T"/> is more general (e.g. object)
        /// or a subclass, the runtime type of <paramref name="value"/> is used instead.
        /// <paramref name="dtype"/> and <paramref name="preferred_dtype"/> are currently ignored.
        /// </remarks>
        /// <exception cref="NotImplementedException">The value's type has no conversion.</exception>
        public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
            string name = "", DataType preferred_dtype = DataType.DtInvalid,
            bool as_ref = false)
        {
            switch (typeof(T).Name)
            {
                case "Tensor":
                    return value as Tensor;
                case "String":
                    return constant_op.constant(Convert.ToString(value), name);
                case "String[]":
                    return constant_op.constant(value as string[], name);
                case "Int32":
                    return constant_op.constant(Convert.ToInt32(value), name);
                case "Double":
                    return constant_op.constant(Convert.ToDouble(value), name);
                case "RefVariable":
                    return (value as RefVariable)._TensorConversionFunction(as_ref: as_ref);
                default:
                    // The static type did not match (e.g. T is object, or a subclass of
                    // RefVariable): fall back to the runtime type of the value.
                    switch (value)
                    {
                        case Tensor tensor:
                            return tensor;
                        case RefVariable variable:
                            return variable._TensorConversionFunction(as_ref: as_ref);
                        case string str:
                            return constant_op.constant(str, name);
                        case string[] strs:
                            return constant_op.constant(strs, name);
                        case int i:
                            return constant_op.constant(i, name);
                        case double d:
                            return constant_op.constant(d, name);
                    }

                    throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor");
            }
        }
    }
}
|