/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
    /// <summary>
    /// Graph-mode variable that keeps a reference tensor (<c>_variable</c>),
    /// an initializer operation and a cached read tensor (<c>_snapshot</c>).
    /// It is built either from constructor arguments or from a
    /// <see cref="VariableDef"/> proto.
    /// </summary>
    public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable>
    {
        public bool _in_graph_mode = true;
        public Tensor _initial_value;
        public string _graph_key;
        public bool _trainable;
        public Tensor _snapshot;
        public bool _save_slice_info;

        private Operation _initializer_op;

        /// <summary>The operation that assigns the initial value to this variable.</summary>
        public override Operation initializer => _initializer_op;

        /// <summary>The operation that produces the underlying variable tensor.</summary>
        public override Operation op => _variable.op;

        /// <summary>Data type of the underlying variable tensor.</summary>
        public TF_DataType dtype => _variable.dtype;

        /// <summary>Shape of the underlying variable tensor, converted with <c>tensor_util.to_shape</c>.</summary>
        public TensorShape shape => tensor_util.to_shape(_variable.shape);

        /// <summary>Name of the underlying variable tensor.</summary>
        public override string name => _variable.name;

        /// <summary>Returns the underlying variable tensor.</summary>
        public Tensor eval() => _variable;

        /// <summary>
        /// Creates a variable either from <paramref name="variable_def"/> or from
        /// <paramref name="initial_value"/>; the two are mutually exclusive.
        /// </summary>
        /// <param name="initial_value">Initial value, or a parameterless callable producing it.</param>
        /// <param name="trainable">Whether the variable is added to the trainable-variables collection.</param>
        /// <param name="collections">Graph collection names to register the variable in.</param>
        /// <param name="validate_shape">Whether the initial value must have a fully defined shape.</param>
        /// <param name="caching_device">Caching device; when non-empty no snapshot tensor is created.</param>
        /// <param name="name">Optional name for the variable scope.</param>
        /// <param name="variable_def">Serialized definition to restore the variable from.</param>
        /// <param name="dtype">Data type used when converting the initial value to a tensor.</param>
        /// <param name="import_scope">Name scope prepended to names read from <paramref name="variable_def"/>.</param>
        /// <exception cref="ValueError">
        /// Both <paramref name="variable_def"/> and <paramref name="initial_value"/> were supplied,
        /// or neither was.
        /// </exception>
        public RefVariable(object initial_value = null,
            bool trainable = true,
            List<string> collections = null,
            bool validate_shape = true,
            string caching_device = "",
            string name = null,
            VariableDef variable_def = null,
            TF_DataType dtype = TF_DataType.DtInvalid,
            string import_scope = "") : base(initial_value,
                    trainable,
                    collections,
                    validate_shape,
                    caching_device,
                    name,
                    dtype)
        {
            _in_graph_mode = true;

            if (variable_def != null)
            {
                if (initial_value != null)
                {
                    throw new ValueError("variable_def and initial_value are mutually exclusive.");
                }

                _init_from_proto(variable_def, import_scope: import_scope);
            }
            else
            {
                _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
            }
        }

        /// <summary>
        /// Rebinds this instance to graph elements named in <paramref name="variable_def"/>,
        /// looked up in the default graph.
        /// </summary>
        /// <param name="variable_def">Definition holding the element names.</param>
        /// <param name="import_scope">Name scope prepended to every name before lookup.</param>
        /// <exception cref="NotImplementedException">The definition carries save-slice info.</exception>
        private void _init_from_proto(VariableDef variable_def, string import_scope = "")
        {
            var g = ops.get_default_graph();

            _variable = g.as_graph_element(
                ops.prepend_name_scope(variable_def.VariableName,
                    import_scope: import_scope)) as Tensor;

            _initializer_op = g.as_graph_element(
                ops.prepend_name_scope(variable_def.InitializerName,
                    import_scope: import_scope)) as Operation;

            // Tests whether initial_value_name exists first for backwards compatibility.
            if (!string.IsNullOrEmpty(variable_def.InitialValueName))
            {
                _initial_value = g.as_graph_element(
                    ops.prepend_name_scope(variable_def.InitialValueName,
                        import_scope: import_scope)) as Tensor;
            }
            else
            {
                _initial_value = null;
            }

            _trainable = variable_def.Trainable;

            _snapshot = g.as_graph_element(
                ops.prepend_name_scope(variable_def.SnapshotName,
                    import_scope: import_scope)) as Tensor;

            if (variable_def.SaveSliceInfoDef != null)
            {
                throw new NotImplementedException("save_slice_info_def");
            }

            // Save-slice info, caching device and constraint are not restored here.
        }

        /// <summary>
        /// Builds the variable, its initializer and (when no caching device is given)
        /// its snapshot tensor from an initial value, then registers the variable in
        /// <paramref name="collections"/>.
        /// </summary>
        /// <param name="initial_value">Value convertible to a tensor, or a callable returning a <see cref="Tensor"/>.</param>
        /// <param name="trainable">Adds the trainable-variables collection when true.</param>
        /// <param name="collections">Collections to register in; defaults to the global-variables collection.</param>
        /// <param name="validate_shape">Requires the initial value's shape to be fully defined.</param>
        /// <param name="caching_device">When non-empty, the snapshot tensor is not created.</param>
        /// <param name="name">Optional scope name; "Variable" is used as the default scope name.</param>
        /// <param name="dtype">Data type passed to the tensor conversion.</param>
        /// <exception cref="ValueError">
        /// <paramref name="initial_value"/> is null, is a callable that does not return a
        /// <see cref="Tensor"/>, or has a shape that is not fully defined while
        /// <paramref name="validate_shape"/> is true.
        /// </exception>
        private void _init_from_args(object initial_value,
            bool trainable = true,
            List<string> collections = null,
            bool validate_shape = true,
            string caching_device = "",
            string name = null,
            TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
            {
                throw new ValueError("initial_value must be specified.");
            }

            // Any single-type-argument Func is treated as an initializer callable.
            var init_from_fn = initial_value.GetType().Name == "Func`1";

            if (collections == null)
            {
                collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES };
            }

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph().graph_key;

            _trainable = trainable;
            if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
            }

            tf_with(ops.init_scope2(), delegate
            {
                var values = init_from_fn ? new object[0] : new object[] { initial_value };
                tf_with(ops.name_scope(name, "Variable", values), scope =>
                {
                    name = scope;

                    if (init_from_fn)
                    {
                        // Use attr_scope and device(None) to simulate the behavior of
                        // colocate_with when the variable we want to colocate with doesn't
                        // yet exist.
                        // NOTE(review): true_name and attr are built but never consumed
                        // in this method - confirm whether an attr scope is still to be wired in.
                        string true_name = ops.name_from_scope_name(name);
                        var attr = new AttrValue
                        {
                            List = new AttrValue.Types.ListValue()
                        };
                        attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));

                        tf_with(ops.name_scope("Initializer"), scope2 =>
                        {
                            // Fail with a clear error instead of a NullReferenceException
                            // when the callable does not return a Tensor.
                            var init_fn = initial_value as Func<Tensor>;
                            if (init_fn == null)
                            {
                                throw new ValueError("initial_value must be a Func<Tensor> when passed as a callable.");
                            }

                            _initial_value = init_fn();
                            _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
                        });

                        _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
                    }
                    // Or get the initial value from a Tensor or Python object.
                    else
                    {
                        _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);

                        var shape = _initial_value.shape;
                        dtype = _initial_value.dtype;
                        _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
                    }

                    // When shape validation is requested, the initial value must have a
                    // fully defined shape.
                    if (validate_shape)
                    {
                        var initial_value_shape = _initial_value.TensorShape;
                        if (!initial_value_shape.is_fully_defined())
                        {
                            throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
                        }
                    }

                    // If 'initial_value' makes use of other variables, make sure we don't
                    // have an issue if these other variables aren't initialized first by
                    // using their initialized_value() method.
                    var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);

                    _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                    if (string.IsNullOrEmpty(caching_device))
                    {
                        // NOTE(review): the result of colocate_with is discarded here -
                        // confirm whether it should wrap the identity op via tf_with.
                        ops.colocate_with(_initializer_op);

                        // Named argument: the previous `name = "read"` assignment expression
                        // overwrote the captured `name` variable as a side effect.
                        _snapshot = gen_array_ops.identity(_variable, name: "read");
                    }

                    ops.add_to_collections(collections, this as VariableV1);
                });
            });
        }

        /// <summary>Returns the underlying variable (reference) tensor.</summary>
        public Tensor _ref() => _variable;

        /// <summary>Returns the snapshot tensor of this variable.</summary>
        public Tensor value() => _snapshot;

        /// <summary>Returns the snapshot tensor of this variable.</summary>
        public Tensor _AsTensor() => _snapshot;

        /// <summary>Returns the underlying variable tensor as the graph element for this variable.</summary>
        public Tensor _as_graph_element() => _variable;

        /// <summary>
        /// Converts this variable to a tensor: the reference tensor when
        /// <paramref name="as_ref"/> is true, otherwise the snapshot.
        /// </summary>
        /// <param name="dtype">Not used by this implementation.</param>
        /// <param name="name">Not used by this implementation.</param>
        /// <param name="as_ref">Selects the reference tensor instead of the snapshot.</param>
        /// <returns>The selected tensor.</returns>
        public Tensor _TensorConversionFunction(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
        {
            return as_ref ? _ref() : value();
        }

        /// <summary>
        /// Attempt to guard against dependencies on uninitialized variables.
        /// </summary>
        /// <param name="name">Suffix used when naming any replacement operations.</param>
        /// <param name="initial_value">The tensor whose dependencies are rewritten.</param>
        /// <returns>A tensor whose variable dependencies go through their initialized values.</returns>
        private Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value)
        {
            return _safe_initial_value_from_tensor(name, initial_value, op_cache: new Dictionary<string, Operation>());
        }

        /// <summary>
        /// Replace dependencies on variables with their initialized values.
        /// </summary>
        /// <param name="name">Suffix used when naming any replacement operations.</param>
        /// <param name="tensor">A `Tensor`. The tensor to replace.</param>
        /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
        /// <returns>A `Tensor` compatible with `tensor`.</returns>
        private Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary<string, Operation> op_cache)
        {
            var op = tensor.op;

            // Single lookup; a cached null entry is treated the same as a miss.
            if (!op_cache.TryGetValue(op.name, out var new_op) || new_op == null)
            {
                new_op = _safe_initial_value_from_op(name, op, op_cache);
                op_cache[op.name] = new_op;
            }

            return new_op.outputs[tensor.value_index];
        }

        /// <summary>
        /// Returns an operation equivalent to <paramref name="op"/> in which reads of
        /// variables are replaced by their initialized values, creating a new operation
        /// only when at least one input changed.
        /// </summary>
        /// <param name="name">Suffix appended to the name of a replacement operation.</param>
        /// <param name="op">The operation to rewrite.</param>
        /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
        /// <returns>The original operation, or its replacement.</returns>
        private Operation _safe_initial_value_from_op(string name, Operation op, Dictionary<string, Operation> op_cache)
        {
            var op_type = op.node_def.Op;

            switch (op_type)
            {
                case "IsVariableInitialized":
                case "VarIsInitializedOp":
                case "ReadVariableOp":
                    return op;
                case "Variable":
                case "VariableV2":
                case "VarHandleOp":
                    var initialized_value = _find_initialized_value_for_variable(op);
                    return initialized_value == null ? op : initialized_value.op;
            }

            // Recursively build initializer expressions for inputs.
            var modified = false;
            var new_op_inputs = new List<Tensor>();
            foreach (var op_input in op.inputs)
            {
                var new_op_input = _safe_initial_value_from_tensor(name, op_input as Tensor, op_cache);
                new_op_inputs.Add(new_op_input);
                modified = modified || new_op_input != op_input;
            }

            // If no input was modified, the op can be reused unchanged.
            if (!modified)
            {
                return op;
            }

            var new_op_type = op_type;
            if (new_op_type == "RefSwitch")
            {
                new_op_type = "Switch";
            }

            var new_op_name = op.node_def.Name + "_" + name;
            new_op_name = new_op_name.Replace(":", "_");

            // Convert attr values to AttrValue protos.
            var attr_protos = new Dictionary<string, AttrValue>();
            foreach (var attr_def in op.node_def.Attr)
            {
                attr_protos[attr_def.Key] = attr_def.Value;
            }

            return op.graph.create_op(new_op_type,
                new_op_inputs.ToArray(),
                op._output_types,
                name: new_op_name,
                attrs: attr_protos);
        }

        /// <summary>
        /// Searches the global and local variable collections of the operation's graph
        /// for a variable whose name matches <paramref name="variable_op"/> (with or
        /// without a ":0" suffix) and returns its initialized value.
        /// </summary>
        /// <param name="variable_op">The variable operation to look up.</param>
        /// <returns>The initialized value of the matching variable, or null when none matches.</returns>
        private Operation _find_initialized_value_for_variable(Operation variable_op)
        {
            var var_names = new[] { variable_op.node_def.Name, variable_op.node_def.Name + ":0" };

            foreach (var collection_name in new[] { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES })
            {
                foreach (var var in variable_op.graph.get_collection<RefVariable>(collection_name))
                {
                    if (var_names.Contains(var.name))
                    {
                        return var.initialized_value();
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// Assigns a new value to the variable.
        /// </summary>
        /// <param name="value">The new value for this variable.</param>
        /// <param name="use_locking">If `True`, use locking during the assignment.</param>
        /// <param name="name">The name of the operation to be created</param>
        /// <param name="read_value">
        /// if True, will return something which evaluates to the
        /// new value of the variable; if False will return the assign op.
        /// </param>
        /// <returns>
        /// A `Tensor` that will hold the new value of this variable after
        /// the assignment has completed.
        /// </returns>
        public ITensorOrOperation assign(object value, bool use_locking = false, string name = null, bool read_value = true)
        {
            var assigned = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
            if (read_value)
            {
                return assigned;
            }

            return assigned.op;
        }

        /// <summary>Returns a short description containing the name, shape and dtype.</summary>
        public override string ToString()
        {
            return $"tf.RefVariable '{name}' shape={shape} dtype={dtype}";
        }

        /// <summary>
        /// Serializes this variable to a <see cref="VariableDef"/>, stripping
        /// <paramref name="export_scope"/> from every recorded name.
        /// </summary>
        /// <param name="export_scope">Name scope to strip; null or empty exports unconditionally.</param>
        /// <returns>The populated definition.</returns>
        /// <exception cref="NotImplementedException">
        /// The variable name does not start with <paramref name="export_scope"/>, or
        /// save-slice info is set.
        /// </exception>
        public VariableDef to_proto(string export_scope)
        {
            // Graph names are identifiers, so compare ordinally rather than by culture.
            if (!string.IsNullOrEmpty(export_scope) &&
                !_variable.name.StartsWith(export_scope, StringComparison.Ordinal))
            {
                throw new NotImplementedException("to_proto RefVariable");
            }

            var var_def = new VariableDef();
            var_def.VariableName = ops.strip_name_scope(_variable.name, export_scope);

            if (_initial_value != null)
            {
                var_def.InitialValueName = ops.strip_name_scope(_initial_value.name, export_scope);
            }

            var_def.Trainable = _trainable;
            var_def.InitializerName = ops.strip_name_scope(initializer.name, export_scope);
            var_def.SnapshotName = ops.strip_name_scope(_snapshot.name, export_scope);

            if (_save_slice_info)
            {
                throw new NotImplementedException("to_proto _save_slice_info");
            }

            return var_def;
        }

        /// <summary>Not implemented; always throws.</summary>
        /// <param name="proto">Unused.</param>
        /// <param name="import_scope">Unused.</param>
        /// <exception cref="NotImplementedException">Always.</exception>
        public RefVariable from_proto(VariableDef proto, string import_scope)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Returns the value of this variable, read in the current context.
        /// </summary>
        /// <returns>An identity of the variable tensor named "read".</returns>
        private ITensorOrOperation read_value()
        {
            return array_ops.identity(_variable, name: "read");
        }

        /// <summary>
        /// Returns the Tensor used as the initial value for the variable.
        /// </summary>
        /// <returns>The stored initial-value tensor.</returns>
        private ITensorOrOperation initial_value()
        {
            return _initial_value;
        }

        /// <summary>Returns a tensor telling whether <paramref name="variable"/> has been initialized.</summary>
        /// <param name="variable">The variable to test.</param>
        public Tensor is_variable_initialized(RefVariable variable)
        {
            return state_ops.is_variable_initialized(variable);
        }

        /// <summary>
        /// Returns a conditional tensor that yields the variable's current value when it
        /// is initialized and its initial value otherwise.
        /// </summary>
        public Tensor initialized_value()
        {
            // NOTE(review): the result of ops.init_scope() is discarded - confirm
            // whether the cond below is meant to be built inside that scope.
            ops.init_scope();
            return control_flow_ops.cond(is_variable_initialized(this),
                read_value,
                initial_value);
        }
    }
}