using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Tensorflow;
using node_def_pb2 = Tensorflow;
using Google.Protobuf;
using System.Linq;

namespace Tensorflow
{
    /// <summary>
    /// Static helpers for creating graph operations through the native C API,
    /// building <c>NodeDef</c> messages and resolving name scopes.
    /// </summary>
    public static class ops
    {
        /// <summary>
        /// Returns the graph produced by <c>tf.Graph()</c>.
        /// </summary>
        public static Graph get_default_graph()
        {
            return tf.Graph();
        }

        /// <summary>
        /// Converts <paramref name="value"/> to an ndarray via
        /// <c>tensor_util.convert_to_numpy_ndarray</c> and wraps it in a constant tensor.
        /// </summary>
        /// <param name="value">The value to convert.</param>
        /// <param name="name">Name forwarded to <c>tf.constant</c>.</param>
        /// <returns>The tensor returned by <c>tf.constant</c>.</returns>
        public static Tensor convert_to_tensor(object value, string name = "")
        {
            var ndarray = tensor_util.convert_to_numpy_ndarray(value);
            return tf.constant(ndarray, name);
        }

        /// <summary>
        /// Creates a native operation in <paramref name="graph"/> from
        /// <paramref name="node_def"/>, wiring up the given inputs and attributes.
        /// </summary>
        /// <param name="graph">Graph that receives the new operation.</param>
        /// <param name="node_def">Supplies the op type, the op name and the attributes.</param>
        /// <param name="inputs">Inputs of the operation; may be null, in which case no input is added.</param>
        /// <returns>The handle returned by <c>c_api.TF_FinishOperation</c>.</returns>
        /// <exception cref="Exception">
        /// Thrown with the status message when finishing the operation does not yield <c>TF_OK</c>.
        /// </exception>
        // NOTE(review): `List` carries no type argument in this signature; it presumably lost
        // its element type (something exposing `_as_tf_output()`) - confirm against the project.
        public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs)
        {
            var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

            // Add inputs. Every entry is added as a single input.
            // TODO: list-valued inputs (c_api.TF_AddInputList) are not handled yet.
            if (inputs != null)
            {
                foreach (var op_input in inputs)
                {
                    c_api.TF_AddInput(op_desc, op_input._as_tf_output());
                }
            }

            var status = new Status();

            // Add control inputs (not implemented).

            // Add attrs: each value is serialized and handed to the C API through
            // an unmanaged buffer.
            foreach (var attr in node_def.Attr)
            {
                var bytes = attr.Value.ToByteArray();
                var proto = Marshal.AllocHGlobal(bytes.Length);
                try
                {
                    Marshal.Copy(bytes, 0, proto, bytes.Length);
                    c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                    status.Check(true);
                }
                finally
                {
                    // The serialized proto is consumed during the call above, so the
                    // buffer is released here - also when the status check throws.
                    Marshal.FreeHGlobal(proto);
                }
            }

            var c_op = c_api.TF_FinishOperation(op_desc, status);
            if (status.Code != TF_Code.TF_OK)
            {
                throw new Exception(status.Message);
            }

            return c_op;
        }

        /// <summary>
        /// Looks up the op definition for <paramref name="type"/> in <paramref name="graph"/>.
        /// </summary>
        public static OpDef _get_op_def(Graph graph, string type)
        {
            return graph.GetOpDef(type);
        }

        /// <summary>
        /// Builds a <c>NodeDef</c> with the given op type, name and attributes.
        /// </summary>
        /// <param name="op_type">Value assigned to <c>NodeDef.Op</c>.</param>
        /// <param name="name">Value assigned to <c>NodeDef.Name</c>.</param>
        /// <param name="device">Currently unused.</param>
        /// <param name="attrs">Attributes copied into <c>NodeDef.Attr</c>; null means no attributes.</param>
        /// <returns>The populated node definition.</returns>
        // NOTE(review): `Dictionary` carries no type arguments in this signature; they were
        // presumably lost - confirm against the project.
        // NOTE(review): `device` is accepted but never applied to the node - confirm intent.
        public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary attrs = null)
        {
            var node_def = new node_def_pb2.NodeDef();
            node_def.Op = op_type;
            node_def.Name = name;

            // attrs defaults to null, so it must not be enumerated unconditionally.
            if (attrs != null)
            {
                foreach (var attr in attrs)
                {
                    node_def.Attr.Add(attr.Key, attr.Value);
                }
            }

            return node_def;
        }

        /// <summary>
        /// Resolves a name scope in the default graph.
        /// </summary>
        /// <param name="name">Requested scope name; when null or empty, <paramref name="default_name"/> is used.</param>
        /// <param name="default_name">Fallback scope name.</param>
        /// <param name="values">Currently unused.</param>
        /// <returns>The value returned by the graph's <c>name_scope</c> for the chosen name.</returns>
        public static string name_scope(string name, string default_name = "", object values = null)
        {
            // Prefer the explicit name; fall back to the default only when none was given.
            string _name = String.IsNullOrEmpty(name) ? default_name : name;

            var g = get_default_graph();
            var _name_scope = g.name_scope(_name);
            return _name_scope;
        }

        /// <summary>
        /// Strips a single trailing "/" from <paramref name="name"/>, if present.
        /// </summary>
        /// <param name="name">Scope name, with or without a trailing slash.</param>
        /// <returns><paramref name="name"/> without its final "/", or unchanged when it has none.</returns>
        public static string _name_from_scope_name(string name)
        {
            if (name.EndsWith("/", StringComparison.Ordinal))
            {
                return name.Substring(0, name.Length - 1);
            }

            return name;
        }

        /// <summary>
        /// Returns an id. Currently always returns 1.
        /// </summary>
        // NOTE(review): the name suggests a unique id, but the value is constant - confirm intent.
        public static int uid()
        {
            return 1;
        }
    }
}