using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; using System.Text; namespace Tensorflow { /// /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. /// This leads to a low-level programming model in which you first define the dataflow graph, /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// https://www.tensorflow.org/guide/graphs /// public partial class Graph : IPython, IDisposable { private IntPtr _handle; private Dictionary _nodes_by_id; public Dictionary _nodes_by_name; private Dictionary _names_in_use; public int _version; private int _next_id_counter; private List _unfetchable_ops = new List(); public string _name_stack = ""; public string _graph_key; public Status Status { get; } /// /// True if the graph is considered "finalized". In that case no /// new operations can be added. /// private bool _finalized = false; /// /// Arbitrary collections of objects. /// private Dictionary _collections = new Dictionary(); public Graph() { _handle = c_api.TF_NewGraph(); Status = new Status(); _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); _graph_key = $"grap-key-{ops.uid()}/"; } public Graph(IntPtr handle) { _handle = handle; Status = new Status(); _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); _graph_key = $"grap-key-{ops.uid()}/"; } public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); } public Graph as_default() => ops.set_default_graph(this); private Tensor _as_graph_element(object obj) { if (obj is RefVariable var) return var._as_graph_element(); return null; } private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) { string types_str = ""; if (allow_tensor && allow_operation) { types_str = "Tensor or Operation"; } else if (allow_tensor) { types_str = "Tensor"; } else if (allow_operation) { types_str = "Operation"; } var temp_obj = _as_graph_element(obj); if (temp_obj != null) obj = temp_obj; // If obj appears to be a name... if (obj is string name) { if(name.Contains(":") && allow_tensor) { string op_name = name.Split(':')[0]; int out_n = int.Parse(name.Split(':')[1]); if (_nodes_by_name.ContainsKey(op_name)) return _nodes_by_name[op_name].outputs[out_n]; } else if(!name.Contains(":") & allow_operation) { if (!_nodes_by_name.ContainsKey(name)) throw new KeyError($"The name {name} refers to an Operation not in the graph."); return _nodes_by_name[name]; } else if (!name.Contains(":") & !allow_operation) { throw new NotImplementedException("_as_graph_element_locked"); } } if (obj is Tensor tensor && allow_tensor) { if (tensor.graph.Equals(this)) { return tensor; } else { throw new Exception($"Tensor {obj} is not an element of this graph."); } } else if (obj is Operation op && allow_operation) { if (op.graph.Equals(this)) { return op; } else { throw new Exception($"Operation {obj} is not an element of this graph."); } } throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); } public void add_to_collection(string name, T value) { _check_not_finalized(); if (_collections.ContainsKey(name)) (_collections[name] as List).Add(value); else _collections[name] = new List { value }; } public void add_to_collections(List names, T value) { foreach (string name in names) add_to_collection(name, value); } private void _check_not_finalized() { if (_finalized) throw new RuntimeError("Graph is finalized and cannot be modified."); } public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary attrs = null, OpDef op_def = null) { if (inputs == null) inputs = new Tensor[0]; foreach ((int idx, Tensor a) in Python.enumerate(inputs)) { } if (String.IsNullOrEmpty(name)) name = op_type; // If a names ends with a '/' it is a "name scope" and we use it as-is, // after removing the trailing '/'. name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); var input_ops = inputs.Select(x => x.op).ToArray(); var control_inputs = _control_dependencies_for_inputs(input_ops); var op = new Operation(node_def, this, inputs: inputs, output_types: dtypes, control_inputs: control_inputs, input_types: input_types, original_op: null, op_def: op_def); _create_op_helper(op, true); /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); Console.WriteLine();*/ return op; } private void _create_op_helper(Operation op, bool compute_device = true) { _record_op_seen_by_control_dependencies(op); } public void _add_op(Operation op) { op._id_value = _next_id(); _nodes_by_id[op._id] = op; _nodes_by_name[op.name] = op; _version = Math.Max(_version, op._id); } public int _next_id() { return ++_next_id_counter; } public bool is_fetchable(T tensor_or_op) { if (tensor_or_op is Tensor) { return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; } else if (tensor_or_op is Operation) { return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); } return false; } public string get_name_scope() { return _name_stack; } public string name_scope(string name) { string new_stack = ""; if (string.IsNullOrEmpty(name)) new_stack = ""; else if (name.EndsWith("/")) new_stack = ops._name_from_scope_name(name); else new_stack = unique_name(name); _name_stack = new_stack; return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; } public string unique_name(string name, bool mark_as_used = true) { if (!String.IsNullOrEmpty(_name_stack)) { name = _name_stack + "/" + name; } var name_key = name.ToLower(); int i = 0; if (_names_in_use.ContainsKey(name_key)) { foreach (var item in _names_in_use) { if (item.Key == name_key) { i = _names_in_use[name_key]; break; } i++; } } if (mark_as_used) if (_names_in_use.ContainsKey(name_key)) _names_in_use[name_key]++; else _names_in_use[name_key] = i + 1; if (i > 0) { var base_name_key = name_key; // Make sure the composed name key is not already used. if (_names_in_use.ContainsKey(name_key)) { name_key = $"{base_name_key}_{i}"; i += 1; } if (mark_as_used) _names_in_use[name_key] = 1; name = $"{name}_{i - 1}"; } return name; } public TF_Output[] ReturnOutputs(IntPtr results) { IntPtr return_output_handle = IntPtr.Zero; int num_return_outputs = 0; c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); TF_Output[] return_outputs = new TF_Output[num_return_outputs]; for (int i = 0; i < num_return_outputs; i++) { var handle = return_output_handle + (Marshal.SizeOf() * i); return_outputs[i] = Marshal.PtrToStructure(handle); } return return_outputs; } public unsafe Operation[] ReturnOperations(IntPtr results) { TF_Operation return_oper_handle = new TF_Operation(); int num_return_opers = 0; c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); Operation[] return_opers = new Operation[num_return_opers]; for (int i = 0; i < num_return_opers; i++) { var handle = return_oper_handle.node + Marshal.SizeOf() * i; return_opers[i] = new Operation(*(IntPtr*)handle); } return return_opers; } public Operation OperationByName(string operName) { return c_api.TF_GraphOperationByName(_handle, operName); } public ITensorOrOperation[] get_operations() { return _nodes_by_name.Values.Select(x => x).ToArray(); } public string[] get_all_collection_keys() { return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); } public object get_collection(string name, string scope = "") { return _collections.ContainsKey(name) ? _collections[name] : null; } public object get_collection_ref(string name) { if (!_collections.ContainsKey(name)) _collections[name] = new List(); return _collections[name]; } public void Dispose() { c_api.TF_DeleteGraph(_handle); } public void __enter__() { } public void __exit__() { } public static implicit operator IntPtr(Graph graph) { return graph._handle; } } }