using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Graph-mode variable backed by a <c>VariableV2</c> op. Construction creates the
    /// variable op, an <c>Assign</c> initializer op and a <c>read</c> identity snapshot,
    /// then registers the variable in the requested graph collections.
    /// </summary>
    public partial class RefVariable : VariableV1
    {
        public bool _in_graph_mode = true;
        public Tensor _initial_value;
        public string _graph_key;
        public bool _trainable;
        public Tensor _variable;
        public Tensor _snapshot;

        private Operation _initializer_op;

        /// <summary>The <c>Assign</c> op that writes the initial value into the variable.</summary>
        public Operation initializer => _initializer_op;

        /// <summary>The op that produces the underlying variable tensor.</summary>
        public Operation op => _variable.op;

        /// <summary>Data type of the underlying variable tensor.</summary>
        public TF_DataType dtype => _variable.dtype;

        /// <summary>Shape of the underlying variable tensor.</summary>
        public TensorShape shape => tensor_util.to_shape(_variable.shape);

        /// <summary>Name of the underlying variable tensor.</summary>
        public string name => _variable.name;

        /// <summary>
        /// Creates a new variable in the default graph, initialized from <paramref name="initial_value"/>.
        /// </summary>
        /// <param name="initial_value">Tensor or convertible value used to initialize the variable. Must not be null.</param>
        /// <param name="trainable">When true, the variable is also added to <c>TRAINABLE_VARIABLES</c>.</param>
        /// <param name="collections">Graph collection keys to add the variable to; defaults to <c>GLOBAL_VARIABLES</c>.
        /// The list passed in is not modified.</param>
        /// <param name="validate_shape">Forwarded to the initializer's assign op.</param>
        /// <param name="caching_device">Device placement for the snapshot is not implemented yet;
        /// the snapshot is created either way.</param>
        /// <param name="name">Optional name, used for the name scope (defaults to "Variable").</param>
        /// <param name="dtype">Forwarded to the base constructor; the variable itself takes the dtype of the converted initial value.</param>
        /// <exception cref="ValueError">Thrown when <paramref name="initial_value"/> is null.</exception>
        public RefVariable(object initial_value,
            bool trainable = true,
            List<string> collections = null,
            bool validate_shape = true,
            string caching_device = "",
            string name = "",
            TF_DataType dtype = TF_DataType.DtInvalid) :
            base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype)
        {
            _in_graph_mode = true;
            _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
        }

        /// <summary>
        /// Builds the variable, initializer and snapshot ops and registers the variable
        /// in its graph collections. Parameters are documented on the constructor.
        /// </summary>
        private void _init_from_args(object initial_value,
            bool trainable = true,
            List<string> collections = null,
            bool validate_shape = true,
            string caching_device = "",
            string name = "",
            TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
                throw new ValueError("initial_value must be specified.");

            // Work on a private copy so the caller's list is never mutated.
            collections = collections == null
                ? new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }
                : new List<string>(collections);

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph()._graph_key;

            _trainable = trainable;

            // Only trainable variables belong in TRAINABLE_VARIABLES.
            if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
                collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);

            ops.init_scope();

            // Initializing from a callable is not supported, so the initial value itself
            // is always the scope's input value.
            var values = new List<object> { initial_value };
            using (var namescope = new ops.name_scope(name, "Variable", values))
            {
                name = namescope;

                // Get the initial value from a Tensor or convertible object.
                _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
                var shape = _initial_value.shape;
                // The dtype argument is overwritten: the variable takes the initial value's dtype.
                dtype = _initial_value.dtype;
                _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);

                // TODO: validate_shape is only forwarded to the assign op below; a check that
                // the initial value's shape is fully defined is not implemented here.

                // If 'initial_value' makes use of other variables, make sure we don't
                // have an issue if these other variables aren't initialized first by
                // using their initialized_value() method.
                var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

                _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                // The snapshot is created regardless of caching_device so _AsTensor() never
                // returns null. Placing it on caching_device is not implemented yet.
                ops.colocate_with(_initializer_op);
                _snapshot = gen_array_ops.identity(_variable, name: "read");

                ops.add_to_collections(collections, this);
            }
        }

        /// <summary>Returns the underlying variable tensor.</summary>
        public Tensor _ref()
        {
            return _variable;
        }

        /// <summary>Returns the <c>read</c> identity snapshot of the variable.</summary>
        public Tensor _AsTensor()
        {
            return _snapshot;
        }

        /// <summary>
        /// Attempt to guard against dependencies on uninitialized variables.
        /// </summary>
        /// <param name="initial_value">The initial value tensor to inspect.</param>
        /// <returns>A tensor to use as the initializer's value.</returns>
        private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value)
        {
            return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>());
        }

        /// <summary>
        /// Replace dependencies on variables with their initialized values.
        /// </summary>
        /// <param name="tensor">A <c>Tensor</c>. The tensor to replace.</param>
        /// <param name="op_cache">A dict mapping operation names to <c>Operation</c>s; filled in as ops are visited.</param>
        /// <returns>A <c>Tensor</c> compatible with <paramref name="tensor"/>.</returns>
        private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache)
        {
            var op = tensor.op;
            // A missing entry or a cached null both mean the op has not been processed yet.
            if (!op_cache.TryGetValue(op.Name, out var new_op) || new_op == null)
            {
                new_op = _safe_initial_value_from_op(op, op_cache);
                op_cache[op.Name] = new_op;
            }

            return new_op.outputs[tensor.value_index];
        }

        /// <summary>
        /// Returns the op to use in place of <paramref name="op"/> when building the initial value.
        /// Currently every op type is returned unchanged. Substituting variable ops and
        /// recursively rewriting inputs are not implemented yet.
        /// </summary>
        private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache)
        {
            var op_type = op.node_def.Op;
            switch (op_type)
            {
                case "IsVariableInitialized":
                case "VarIsInitializedOp":
                case "ReadVariableOp":
                    return op;
                case "Variable":
                case "VariableV2":
                case "VarHandleOp":
                    // TODO: substitute the variable's initialized value here.
                    break;
            }

            // TODO: Recursively build initializer expressions for inputs.
            return op;
        }
    }
}