using Tensorflow.NumPy;
using System;
using Tensorflow.Eager;
using Tensorflow.Variables;
using Tensorflow.Train;
using static Tensorflow.Binding;
using System.Collections.Generic;
using System.Diagnostics;
using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel;
using OneOf;
using Tensorflow.Graphs;

namespace Tensorflow
{
    /// <summary>
    /// Base class for variables whose state lives behind a resource handle tensor
    /// (<see cref="Handle"/>). Most reads and writes are issued as
    /// <c>gen_resource_variable_ops</c> ops (read_variable_op, assign_variable_op,
    /// assign_add_variable_op, assign_sub_variable_op, ...) applied to that handle.
    /// </summary>
    /// <remarks>
    /// The parameterless constructor does nothing; state is populated by
    /// <see cref="__init__"/> or by subclasses writing the protected fields directly.
    /// </remarks>
    public class BaseResourceVariable : DisposableTrackableObject
    {
        protected string _name;

        /// <summary>
        /// Variable name. Returns <c>_handle_name</c>, which <see cref="__init__"/> stores
        /// with a ":0" suffix appended.
        /// </summary>
        public virtual string Name => _handle_name;

        /// <summary>
        /// Prefix of <c>_handle_name</c> up to and including the first ':'
        /// (e.g. "v:0" yields "v:"). When the name contains no ':', IndexOf returns -1,
        /// so the length is 0 and the result is "".
        /// </summary>
        public virtual string SharedName
        {
            get
            {
                // TODO(Rinne): optimize the implementation with refactor of variable.
                // NOTE(review): TF Python's `_shared_name` is `name[:name.index(":")]`, which
                // drops the ':'. The `+ 1` here keeps it. Confirm that callers (e.g. the
                // SavedModel copy path) really expect a trailing ':'.
                return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1);
            }
        }

        protected TF_DataType _dtype;
        public TF_DataType dtype => _dtype;

        protected string _handle_name;

        /// <summary>
        /// Raw accessor for <c>_handle_name</c>. Unlike <see cref="__init__"/>, the setter
        /// stores the value as-is and does not append ":0".
        /// </summary>
        public string handle_name
        {
            get { return _handle_name; }
            set { _handle_name = value; }
        }

        protected string _unique_id;
        public string UniqueId => _unique_id;

        protected bool _in_graph_mode;
        internal bool InGraphMode => _in_graph_mode;

        protected bool _trainable;
        public bool Trainable => _trainable;

        protected Tensor _initial_value;

        /// <summary>Returns <c>initializer_op</c>; same value as <see cref="Initializer"/>.</summary>
        public Operation initializer => initializer_op;

        protected Tensor _parent_op;
        public Tensor parent_op => _parent_op;

        /// <summary>
        /// Tensor handle
        /// </summary>
        protected Tensor handle;
        public Tensor Handle => handle;

        protected Tensor _graph_element;

        /// <summary>
        /// When non-null, <see cref="value"/> returns this tensor instead of issuing a new read.
        /// </summary>
        public Tensor GraphElement => _graph_element;

        protected Shape _shape;
        public Shape shape => _shape;

        protected Operation initializer_op;
        public Operation Initializer => initializer_op;

        // These forward to the handle tensor and throw NullReferenceException when `handle` is null.
        public Operation Op => handle.op;
        public Graph Graph => handle.graph;
        public string Device => handle.Device;

        // Never assigned in this class: the only assignment (in __init__) is commented out.
        EagerResourceDeleter eager_resource_deleter;

        public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None;

        public BaseResourceVariable()
        {
        }

        /// <summary>
        /// Populates this variable's fields from an already-created resource handle.
        /// </summary>
        /// <param name="trainable">Stored in <c>_trainable</c>.</param>
        /// <param name="shape">Stored in <c>_shape</c> only when non-null.</param>
        /// <param name="dtype">Stored in <c>_dtype</c> only when not <c>DtInvalid</c>.</param>
        /// <param name="handle">Resource handle tensor. May be null.</param>
        /// <param name="name">Stored in <c>_name</c>.</param>
        /// <param name="unique_id">Stored in <c>_unique_id</c>.</param>
        /// <param name="handle_name">Base name. ":0" is appended, so a null value yields ":0".</param>
        public void __init__(bool trainable = true, Shape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, Tensor handle = null, string name = null, string unique_id = null, string handle_name = null)
        {
            _trainable = trainable;
            _handle_name = handle_name + ":0";
            _unique_id = unique_id;
            this.handle = handle;
            _name = name;
            if(shape is not null)
            {
                _shape = shape;
            }
            if(dtype != TF_DataType.DtInvalid)
            {
                _dtype = dtype;
            }
            // After the handle has been created, set up a way to clean it up when
            // executing eagerly. We'll hold the only reference to the deleter, so that
            // when this object is garbage collected the deleter will be too. This
            // means ResourceVariables can be part of reference cycles without those
            // cycles being uncollectable.
            // NOTE(review): the deleter creation below is commented out, so the paragraph above
            // describes intended rather than current behavior.
            // `_handle` is not declared in this class (presumably inherited from
            // DisposableTrackableObject — confirm). It receives the raw native pointer of the handle.
            if (handle is EagerTensor)
            {
                _handle = handle.EagerTensorHandle.DangerousGetHandle();
                // eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
            }
            else if(handle is null)
            {
                // TODO: fix this dangerous change.
                _handle = IntPtr.Zero;
            }
            else
            {
                _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle();
            }
#if TRACK_TENSOR_LIFE
            print($"Created Resource 0x{_handle.ToString("x16")} {_name}");
#endif
        }

        /// <summary>
        /// Assigns <paramref name="value"/> to this variable.
        /// </summary>
        /// <remarks>
        /// Tensor path: used when the runtime type of <paramref name="value"/> is exactly
        /// <c>Tensor</c>. This is a <c>GetType()</c> comparison, so subclasses such as
        /// <c>EagerTensor</c> do not match. It calls <c>gen_state_ops.assign</c> and returns
        /// the assign result (read_value) or its op.
        ///
        /// Other values: converted to a tensor of <see cref="dtype"/> and written with
        /// <c>assign_variable_op</c>. The method then returns a new
        /// <c>read_variable_op</c> when <paramref name="read_value"/> is true, otherwise the
        /// assign op.
        ///
        /// Parameter notes:
        /// - <paramref name="use_locking"/> is only forwarded on the Tensor path.
        /// - A null <paramref name="value"/> throws NullReferenceException at <c>GetType()</c>.
        ///
        /// NOTE(review): <c>gen_state_ops.assign</c> appears to be the ref-variable Assign op,
        /// here fed a resource handle. TF Python's resource variables always use
        /// assign_variable_op. Confirm the Tensor path works in graph mode.
        /// </remarks>
        public Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true)
        {
            if (value.GetType() == typeof(Tensor))
            {
                var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
                if (read_value)
                    return assign;
                return assign.op;
            }
            var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
            var assign_op = gen_resource_variable_ops.assign_variable_op(handle, value_tensor, name: name);
            if (read_value)
                return gen_resource_variable_ops.read_variable_op(handle, dtype);
            // Presumably guards the implicit Operation -> Tensor conversion against a null op.
            if (assign_op == null)
                return null;
            return assign_op;
        }

        /// <summary>
        /// Writes <paramref name="value"/> into the region of this variable described by
        /// <paramref name="slice"/>, using its packed begin/end/strides tensors.
        /// </summary>
        /// <remarks>
        /// NOTE(review): no masks are forwarded, so all masks take their default of 0.
        /// Confirm that <c>ParsedSliceArgs</c> never needs mask bits here.
        /// </remarks>
        public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
        {
            _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
        }

        // Emits resource_strided_slice_assign on this handle.
        // NOTE(review): `name` is not forwarded to the op, and the resulting `op` is discarded.
        void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
            int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
        {
            var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
                begin_mask: begin_mask,
                end_mask: end_mask,
                ellipsis_mask: ellipsis_mask,
                new_axis_mask: new_axis_mask,
                shrink_axis_mask: shrink_axis_mask);
        }

        /// <summary>
        /// Assigns <paramref name="value"/> and returns an unread view of the variable
        /// (via <c>_lazy_read</c>) instead of reading the new value immediately.
        /// </summary>
        public IVariableV1 assign_lazy_load(Tensor value, string name = null)
        {
            var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
            var assign_op = gen_resource_variable_ops.assign_variable_op(handle, value_tensor, name: name);
            var variable = _lazy_read(assign_op, value_tensor);
            return variable;
        }

        /// <summary>
        /// Returns <see cref="GraphElement"/> when set. Otherwise issues a new read via
        /// <c>_read_variable_op</c>.
        /// </summary>
        public Tensor value() => GraphElement ?? _read_variable_op();

        /// <summary>
        /// Reads the current value of the variable.
        /// </summary>
        /// <remarks>
        /// Steps:
        /// 1. Records the access (tapes / FuncGraph) via <c>variable_accessed</c>.
        /// 2. Emits <c>read_variable_op</c>. When <paramref name="no_copy"/> is true, it first
        ///    calls <c>disable_copy_on_read</c> on the handle.
        /// 3. Calls <c>_maybe_set_handle_data</c> on the result.
        /// 4. When not executing eagerly, records the read on the tape with an identity
        ///    backward function.
        /// 5. If the result's rank is unknown (ndim == -1), sets its graph shape to this
        ///    variable's <see cref="shape"/>.
        /// </remarks>
        /// <param name="no_copy">When true, calls <c>disable_copy_on_read</c> before the read.</param>
        /// <returns>The read tensor.</returns>
        protected Tensor _read_variable_op(bool no_copy = false)
        {
            variable_accessed(this);

            Tensor read_and_set_handle(bool no_copy)
            {
                if (no_copy)
                {
                    gen_resource_variable_ops.disable_copy_on_read(handle);
                }
                var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
                resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
                return result;
            }

            // TODO(Rinne): deal with caching device.
            var result = read_and_set_handle(no_copy);
            if (!tf.Context.executing_eagerly())
            {
                // The backward function passes the incoming gradient through unchanged.
                tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle },
                    backward_function: (x, _) => x);
            }

            // have to set shape when converting to substituent placeholder
            if (result.shape.ndim == -1)
            {
                c_api.TF_GraphSetTensorShape(result.graph, result._as_tf_output(), shape.dims, shape.ndim, tf.Status);
                // Presumably throws when the status is not OK.
                tf.Status.Check(true);
            }

            return result;
        }

        // Marks this variable as accessed and returns an _UnreadVariable that shares this handle.
        // NOTE(review): `op` and `value` are unused. TF Python's `_lazy_read` passes the op as
        // `parent_op`, but here nothing about `op` reaches the _UnreadVariable.
        IVariableV1 _lazy_read(Operation op, Tensor value)
        {
            variable_accessed(this);
            return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
        }

        /// <summary>
        /// Records that `variable` was accessed for the tape and FuncGraph.
        /// </summary>
        /// <remarks>
        /// FuncGraph: the variable is watched whenever the default graph is a
        /// <c>FuncGraph</c>.
        ///
        /// Tapes: only trainable variables are reported to the active tapes.
        ///
        /// NOTE(review): both <c>as</c> casts yield null when <paramref name="variable"/> is not
        /// an <c>IVariableV1</c> / <c>ResourceVariable</c>. Confirm the callees tolerate null.
        /// </remarks>
        void variable_accessed(BaseResourceVariable variable)
        {
            if(ops.get_default_graph() is FuncGraph func_graph)
            {
                func_graph.watch_variable(variable as IVariableV1);
            }

            if (variable.Trainable)
            {
                foreach (var tape in tf.GetTapeSet())
                    tape.VariableAccessed(variable as ResourceVariable);
            }
        }

        /// <summary>
        /// Constructs an op which reads the value of this variable.
        /// </summary>
        /// <remarks>
        /// Should be used when there are multiple reads, or when it is desirable to
        /// read the value only after some condition is true.
        /// </remarks>
        /// <returns>
        /// <c>array_ops.identity</c> of a read performed inside a "Read" name scope.
        /// </returns>
        protected Tensor read_value()
        {
            var value = tf_with(ops.name_scope("Read"), delegate
            {
                return _read_variable_op();
            });
            return array_ops.identity(value);
        }

        /// <summary>
        /// Adds <paramref name="delta"/> (converted to <see cref="dtype"/>) to the variable.
        /// </summary>
        /// <remarks><paramref name="use_locking"/> is ignored.</remarks>
        /// <returns>
        /// A new <c>read_variable_op</c> when <paramref name="read_value"/> is true, otherwise
        /// the assign-add op.
        /// </returns>
        public Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true)
        {
            var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
                ops.convert_to_tensor(delta, dtype: dtype), name: name);

            if (read_value)
                return gen_resource_variable_ops.read_variable_op(handle, dtype);
            // return _lazy_read(assign_add_op);
            return assign_add_op;
        }

        /// <summary>
        /// Subtracts <paramref name="delta"/> (converted to <see cref="dtype"/>) from the variable.
        /// </summary>
        /// <remarks><paramref name="use_locking"/> is ignored.</remarks>
        /// <returns>
        /// A new <c>read_variable_op</c> when <paramref name="read_value"/> is true, otherwise
        /// the assign-sub op.
        /// </returns>
        public Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true)
        {
            var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
                ops.convert_to_tensor(delta, dtype: dtype), name: name);

            if (read_value)
                return gen_resource_variable_ops.read_variable_op(handle, dtype);
            // return _lazy_read(assign_sub_op);
            return assign_sub_op;
        }

        /// <summary>
        /// Subtracts <paramref name="delta"/> and returns an unread view of the variable
        /// (via <c>_lazy_read</c>) instead of reading the new value immediately.
        /// </summary>
        public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
        {
            var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
                ops.convert_to_tensor(delta, dtype: dtype), name: name);

            return _lazy_read(assign_sub_op, delta);
        }

        /// <summary>
        /// Describes the variable's name, shape and dtype. When executing eagerly, the
        /// description also includes the current value, which requires a read.
        /// </summary>
        public override string ToString()
        {
            if (tf.Context.executing_eagerly())
                return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={read_value().numpy()}";
            else
                return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
        }

        /// <summary>Reads the variable via <c>read_value()</c> and converts the result with <c>.numpy()</c>.</summary>
        public NDArray numpy() => read_value().numpy();

        /// <summary>
        /// Releases nothing itself. When TRACK_TENSOR_LIFE is defined, it logs the pointer.
        /// Note that the <paramref name="handle"/> parameter shadows the <c>handle</c> field.
        /// </summary>
        protected override void DisposeUnmanagedResources(IntPtr handle)
        {
#if TRACK_TENSOR_LIFE
            print($"Deleted Resource 0x{handle.ToString("x16")} {_name}");
#endif
        }

        /// <summary>
        /// Converts the variable to a tensor.
        /// </summary>
        /// <remarks>
        /// - When <paramref name="as_ref"/> is true, returns input 0 of the op built by
        ///   <c>read_value()</c>, i.e. the tensor that was passed to <c>array_ops.identity</c>.
        /// - Otherwise returns <see cref="value"/>.
        /// - <paramref name="dtype"/> and <paramref name="name"/> are ignored.
        /// </remarks>
        public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
        {
            if (as_ref)
                return read_value().op.inputs[0];
            else
                return value();
        }

        /// <summary>
        /// Creates an uninitialized copy of this variable for SavedModel export.
        /// </summary>
        /// <remarks>
        /// Device: when the variable policy saves devices, the copy is made under
        /// <see cref="Device"/>.
        ///
        /// Returns two maps: this variable to its copy, and this handle to the copy's handle.
        ///
        /// NOTE(review): both branches cast <c>this</c> to <c>ResourceVariable</c> and throw
        /// InvalidCastException otherwise. <c>Debug.Assert</c> only checks this in Debug builds,
        /// and only on the device branch.
        /// </remarks>
        public override (IDictionary, IDictionary) map_resources(SaveOptions save_options)
        {
            BaseResourceVariable new_variable;
            if (save_options.experimental_variable_policy.save_variable_devices())
            {
                Debug.Assert(this is ResourceVariable);
                new_variable = tf_with(ops.device(this.Device), _ =>
                {
                    return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
                });
            }
            else
            {
                new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
            }
            Dictionary obj_map = new();
            Dictionary resource_map = new();
            obj_map[this] = new_variable;
            resource_map[this.handle] = new_variable.handle;
            return (obj_map, resource_map);
        }

        /// <summary>
        /// Writes additional information of the variable into the SavedObject proto.
        /// Subclasses of ResourceVariables could choose to override this method to
        /// customize extra information to provide when saving a SavedModel.
        /// </summary>
        /// <param name="proto">The proto to write into.</param>
        /// <param name="options">Save options, forwarded to <c>write_object_proto_for_resource_variable</c>.</param>
        public virtual void write_object_proto(SavedObject proto, SaveOptions options)
        {
            resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
        }

        /// <summary>
        /// Exposes this variable itself as the single checkpoint saveable, under
        /// <c>Trackable.Constants.VARIABLE_VALUE_KEY</c>. The factory ignores its argument.
        /// </summary>
        public override IDictionary>> gather_saveables_for_checkpoint()
        {
            var res = new Dictionary>>();
            res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this;
            return res;
        }

        /// <summary>Returns the <c>var_is_initialized_op</c> tensor for this variable's handle.</summary>
        public Tensor is_initialized(string name = null)
        {
            return gen_resource_variable_ops.var_is_initialized_op(this.handle, name);
        }

        /// <summary>
        /// Intended to read without copying. Currently it behaves like <c>read_value()</c>:
        /// <c>_read_variable_op()</c> is called with its default <c>no_copy = false</c>.
        /// </summary>
        public Tensor read_value_no_copy()
        {
            Tensor value = null;
            tf_with(ops.name_scope("Read"), _ =>
            {
                // TODO: `no_copy = true`.
                value = _read_variable_op();
            });
            return array_ops.identity(value);
        }

        // Operator overloads below are currently disabled.
        //public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y;
        //public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y;
        //public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y;
        //public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value();
        //public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y;
        //public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y;
        //public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y;
        //public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y;
        //public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value();

        //public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value();
        //public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y;
        //public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y;

        //public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y;

        //public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y;
    }
}