From 93eb56e5a3cf734f69faf272fe4da3bc6756fec0 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 21 Jun 2019 00:00:00 -0500 Subject: [PATCH] add apply_adam, _apply_dense for Adam. #271 --- .../_InitializeClustersOpFactory.cs | 12 +- .../Gradients/array_grad.cs | 2 +- .../Gradients/gradients_util.cs | 10 +- src/TensorFlowNET.Core/Gradients/nn_grad.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 184 +++++++++--------- src/TensorFlowNET.Core/Train/AdamOptimizer.cs | 21 +- src/TensorFlowNET.Core/Train/Optimizer.cs | 2 +- .../Train/gen_training_ops.py.cs | 23 +++ 8 files changed, 146 insertions(+), 110 deletions(-) diff --git a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs index 1b985bf9..0e437d51 100644 --- a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs +++ b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs @@ -47,10 +47,10 @@ namespace Tensorflow.Clustering _cluster_centers_updated = cluster_centers_updated; _cluster_centers_initialized = cluster_centers_initialized; - _num_selected = array_ops.shape(_cluster_centers)[0]; + _num_selected = array_ops.shape(_cluster_centers).slice(0); _num_remaining = _num_clusters - _num_selected; - _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); + _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i).slice(0)).ToArray()); } private Tensor _initialize() @@ -68,7 +68,7 @@ namespace Tensorflow.Clustering }, () => { - return control_flow_ops.no_op().output[0]; + return control_flow_ops.no_op().output.slice(0); }); }); } @@ -90,7 +90,7 @@ namespace Tensorflow.Clustering // Adds some centers and returns the number of centers remaining. var new_centers = _choose_initial_centers(); if (_distance_metric == KMeans.COSINE_DISTANCE) - new_centers = nn_impl.l2_normalize(new_centers[0], axis: 1); + new_centers = nn_impl.l2_normalize(new_centers.slice(0), axis: 1); // If cluster_centers is empty, it doesn't have the right shape for concat. var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0), @@ -99,12 +99,12 @@ namespace Tensorflow.Clustering var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false); - return _num_clusters - array_ops.shape(a)[0]; + return _num_clusters - array_ops.shape(a).slice(0); } private Tensor _choose_initial_centers() { - return _greedy_batch_sampler()[0]; + return _greedy_batch_sampler().slice(0); } private Tensor _greedy_batch_sampler() diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 4896d4dd..eec3521b 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -156,7 +156,7 @@ namespace Tensorflow.Gradients // For axis 0 gathers, build an appropriately shaped IndexedSlices. if((int)axis_static == 0) { - var params_tail_shape = params_shape[new NumSharp.Slice(start:1)]; + var params_tail_shape = params_shape.slice(new NumSharp.Slice(start:1)); var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); var values = array_ops.reshape(grad, values_shape); indices = array_ops.reshape(indices, indices_size); diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 12a50479..c2d54057 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -105,16 +105,16 @@ namespace Tensorflow var has_out_grads = true; if (has_out_grads && !stop_ops.Contains(op)) { + // A grad_fn must be defined, either as a function or as None + // for ops that do not have gradients. + var grad_fn = ops.get_gradient_function(op); + if (is_func_call) { } else { - // A grad_fn must be defined, either as a function or as None - // for ops that do not have gradients. - var grad_fn = ops.get_gradient_function(op); - foreach (var (i, out_grad) in enumerate(out_grads)) { if (out_grad == null) @@ -322,7 +322,7 @@ namespace Tensorflow else { used = "add_n"; - out_grads[i] = new List { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) }; + return_grads[i] = _MultiDeviceAddN(out_grad.ToArray(), gradient_uid); } } else diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 7ef39cde..71a833ad 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -200,7 +200,7 @@ namespace Tensorflow.Gradients var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64), array_ops.size(in_shape) - 1); - var outerdim = array_ops.shape(ind_2d)[0]; + var outerdim = array_ops.shape(ind_2d).slice(0); // Compute linear indices(flattened to 1D). var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index f782451f..9571287d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -224,116 +224,110 @@ namespace Tensorflow } } - public Tensor this[Slice slice] + public Tensor slice(Slice slice) { - get - { - var slice_spec = new int[] { slice.Start.Value }; - var begin = new List(); - var end = new List(); - var strides = new List(); + var slice_spec = new int[] { slice.Start.Value }; + var begin = new List(); + var end = new List(); + var strides = new List(); - var index = 0; - var (new_axis_mask, shrink_axis_mask) = (0, 0); - var (begin_mask, end_mask) = (0, 0); - var ellipsis_mask = 0; + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; - foreach (var s in slice_spec) + foreach (var s in slice_spec) + { + begin.Add(s); + if (slice.Stop.HasValue) { - begin.Add(s); - if(slice.Stop.HasValue) - { - end.Add(slice.Stop.Value); - } - else - { - end.Add(0); - end_mask |= (1 << index); - } - strides.Add(slice.Step); - - index += 1; + end.Add(slice.Stop.Value); } - - return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + else { - string name = scope; - if (begin != null) - { - var (packed_begin, packed_end, packed_strides) = - (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); - - return gen_array_ops.strided_slice( - this, - packed_begin, - packed_end, - packed_strides, - begin_mask: begin_mask, - end_mask: end_mask, - shrink_axis_mask: shrink_axis_mask, - new_axis_mask: new_axis_mask, - ellipsis_mask: ellipsis_mask, - - name: name); - } - - throw new NotImplementedException(""); - }); + end.Add(0); + end_mask |= (1 << index); + } + strides.Add(slice.Step); + + index += 1; } + + return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + + name: name); + } + + throw new NotImplementedException(""); + }); } - public Tensor this[int start] + public Tensor slice(int start) { - get - { - var slice_spec = new int[] { start }; - var begin = new List(); - var end = new List(); - var strides = new List(); + var slice_spec = new int[] { start }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; - var index = 0; - var (new_axis_mask, shrink_axis_mask) = (0, 0); - var (begin_mask, end_mask) = (0, 0); - var ellipsis_mask = 0; + foreach (var s in slice_spec) + { + begin.Add(s); + end.Add(s + 1); + strides.Add(1); + shrink_axis_mask |= (1 << index); + index += 1; + } - foreach (var s in slice_spec) + return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) { - begin.Add(s); - end.Add(s + 1); - strides.Add(1); - shrink_axis_mask |= (1 << index); - index += 1; + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + + name: name); } - return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => - { - string name = scope; - if (begin != null) - { - var (packed_begin, packed_end, packed_strides) = - (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); - - return gen_array_ops.strided_slice( - this, - packed_begin, - packed_end, - packed_strides, - begin_mask: begin_mask, - end_mask: end_mask, - shrink_axis_mask: shrink_axis_mask, - new_axis_mask: new_axis_mask, - ellipsis_mask: ellipsis_mask, - - name: name); - } - - throw new NotImplementedException(""); - }); - } + throw new NotImplementedException(""); + }); } public override string ToString() diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index 4a801b32..1086828a 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -16,7 +16,7 @@ namespace Tensorflow.Train float _beta1; float _beta2; float _epsilon; - Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t; + Tensor _beta1_t, _beta2_t, _epsilon_t; public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") : base(learning_rate, use_locking, name) @@ -34,6 +34,25 @@ namespace Tensorflow.Train }); } + public override Operation _apply_dense(Tensor grad, RefVariable var) + { + var m = get_slot(var, "m"); + var v = get_slot(var, "v"); + var (beta1_power, beta2_power) = _get_beta_accumulators(); + return gen_training_ops.apply_adam( + var, + m, + v, + math_ops.cast(beta1_power, var.dtype.as_base_dtype()), + math_ops.cast(beta2_power, var.dtype.as_base_dtype()), + math_ops.cast(_lr_t, var.dtype.as_base_dtype()), + math_ops.cast(_beta1_t, var.dtype.as_base_dtype()), + math_ops.cast(_beta2_t, var.dtype.as_base_dtype()), + math_ops.cast(_epsilon_t, var.dtype.as_base_dtype()), + grad, + use_locking: _use_locking).op; + } + private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func scatter_add) { var (beta1_power_v, beta2_power_v) = _get_beta_accumulators(); diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index c7a31b9d..0ce5225f 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -272,7 +272,7 @@ namespace Tensorflow public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) { var (unique_indices, new_index_positions) = array_ops.unique(indices); - var shape = array_ops.shape(unique_indices)[0]; + var shape = array_ops.shape(unique_indices).slice(0); var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape); return (summed_values, unique_indices); } diff --git a/src/TensorFlowNET.Core/Train/gen_training_ops.py.cs b/src/TensorFlowNET.Core/Train/gen_training_ops.py.cs index 0f1fb271..53726c3f 100644 --- a/src/TensorFlowNET.Core/Train/gen_training_ops.py.cs +++ b/src/TensorFlowNET.Core/Train/gen_training_ops.py.cs @@ -8,6 +8,29 @@ namespace Tensorflow { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + public static Tensor apply_adam(RefVariable var, RefVariable m, RefVariable v, Tensor beta1_power, Tensor beta2_power, + Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, + bool use_locking = false, bool use_nesterov = false, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ApplyAdam", name, new + { + var, + m, + v, + beta1_power, + beta2_power, + lr, + beta1, + beta2, + epsilon, + grad, + use_locking, + use_nesterov + }); + + return _op.outputs[0]; + } + public static Tensor apply_gradient_descent(RefVariable var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) { var _op = _op_def_lib._apply_op_helper("ApplyGradientDescent", name, new