diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs index 516d0163..ef422968 100644 --- a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs +++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs @@ -11,10 +11,26 @@ namespace Tensorflow.Framework { Tensor _values; public Tensor values => _values; + Tensor _indices; + public Tensor indices => _indices; + Tensor _dense_shape; + public Tensor dense_shape => _dense_shape; + + public string name => _values.name; + + public string device => _values.Device; + + public Operation op => _values.op; + + public TF_DataType dtype => _values.dtype; + + public Graph graph => _values.graph; public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null) { - + _values = values; + _indices = indices; + _dense_shape = dense_shape; } public static implicit operator Tensor(IndexedSlices indexedSlices) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index aa74f7f1..58dd7e4a 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -83,13 +83,13 @@ namespace Tensorflow.Gradients new Tensor[] { non_neg_concat_dim, tf.constant(0) }, new Tensor[] { tf.constant(1), tf.constant(-1) }); var squeeze_sizes = array_ops.squeeze(slice); - out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList(); + out_grads = gen_array_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList(); } else { - var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes); + var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes); foreach (var (begin, size) in zip(offset, sizes)) - out_grads.Add(gen_ops.slice(grad, begin, size)); + out_grads.Add(gen_array_ops.slice(grad, begin, size)); } return (end_value_index <= dim_index ? @@ -129,7 +129,7 @@ namespace Tensorflow.Gradients if (fully_known) return sizes; else - return gen_ops.shape_n(inputs); + return gen_array_ops.shape_n(inputs); } /// diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index c68fdfa3..12a50479 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -93,10 +93,7 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); - if(op.name == "embedding/ExpandDims") - { - } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); //if (loop_state != null) //loop_state.EnterGradWhileContext(op, before: true); @@ -311,9 +308,10 @@ namespace Tensorflow // Aggregate multiple gradients, and convert [] to None. if (out_grad.Count > 0) { + string used = ""; if (out_grad.Count < 2) { - string used = "nop"; + used = "nop"; if (out_grad.Count == 0) { throw new ValueError("_AggregatedGrads out_grad.Length == 0"); @@ -321,6 +319,11 @@ namespace Tensorflow return_grads[i] = out_grad[0]; } + else + { + used = "add_n"; + out_grads[i] = new List { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) }; + } } else { @@ -331,6 +334,38 @@ namespace Tensorflow return return_grads; } + /// + /// Adds tensors from potentially multiple devices. + /// + /// + /// + /// + private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid) + { + // Basic function structure comes from control_flow_ops.group(). + // Sort tensors according to their devices. + var tensors_on_device = new Dictionary>(); + + foreach (var tensor in tensor_list) + { + if (!tensors_on_device.ContainsKey(tensor.Device)) + tensors_on_device[tensor.Device] = new List(); + + tensors_on_device[tensor.Device].Add(tensor); + } + + // For each device, add the tensors on that device first. + var summands = new List(); + foreach(var dev in tensors_on_device.Keys) + { + var tensors = tensors_on_device[dev]; + ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true); + summands.Add(math_ops.add_n(tensors.ToArray())); + } + + return math_ops.add_n(summands.ToArray()); + } + /// /// The set of ops that terminate the gradient computation. /// diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 4a6ee82d..c997f179 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -276,6 +276,9 @@ namespace Tensorflow }); } + public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) + => gen_array_ops.unique(x, out_idx: out_idx, name: name); + public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) { if( x == null && y == null) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index fb980259..087a2430 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -26,6 +26,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ConcatOffset", name: name, args: new { concat_dim, shape }); + + return _op.outputs; + } + /// /// Returns a diagonal tensor with a given diagonal values. /// @@ -205,6 +212,21 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Finds unique elements in a 1-D tensor. + /// + /// + /// + /// + /// + public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx }); + // TODO + throw new NotImplementedException("_result = _UniqueOutput._make(_result)"); + // return _op.outputs[0]; + } + public static Tensor where() { throw new NotImplementedException("where"); @@ -271,6 +293,26 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Return a slice from 'input' + /// + /// + /// + /// + /// + /// + public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); + return _op.outputs[0]; + } + + public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); + return _op.outputs; + } + public static Tensor tile(Tensor input, Tensor multiples, string name = null) { var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index e5670dd0..763a4bd8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -16,6 +16,19 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Add all input tensors element wise. + /// + /// + /// + /// + public static Tensor add_n(Tensor[] inputs, string name = null) + { + var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); + + return _op.outputs[0]; + } + /// /// Returns the index with the largest value across dimensions of a tensor. /// @@ -198,6 +211,20 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null) + { + var _op = _op_def_lib._apply_op_helper("UnsortedSegmentSum", name, new { data, segment_ids, num_segments }); + return _op.outputs[0]; + } + public static Tensor tan(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 580ae33c..29e9d671 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -44,8 +44,8 @@ namespace Tensorflow return array_ops.identity(values, name: name); return values; } - throw new NotImplementedException("math_ops add_n n > 1"); - // return gen_math_ops.add_n(inputs, name: name); + + return gen_math_ops.add_n(inputs, name: name); } public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) @@ -126,6 +126,9 @@ namespace Tensorflow public static Tensor equal(Tx x, Ty y, string name = null) => gen_math_ops.equal(x, y, name: name); + public static Tensor sqrt(Tensor x, string name = null) + => gen_math_ops.sqrt(x, name: name); + public static Tensor multiply(Tx x, Ty y, string name = null) => gen_math_ops.mul(x, y, name: name); @@ -319,6 +322,17 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, min); } + /// + /// Computes the sum along segments of a tensor. + /// + /// + /// + /// + /// + /// + public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null) + => gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments, name: name); + /// /// Casts a tensor to type `int32`. /// diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index cecdbd38..63f440c1 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,10 +5,10 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.8.1 + 0.8.2 Haiping Chen SciSharp STACK - true + false Apache 2.0 https://github.com/SciSharp/TensorFlow.NET git @@ -17,14 +17,15 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.8.1.0 + 0.8.2.0 Changes since v0.8: 1. Remove global static graph instance. 2. Provide custom gradient function. -3. Add gradient function for Conv2D. +3. Add gradient function for Conv2D. +4. Fix bug for Transfer Learning example. 7.2 - 0.8.1.0 + 0.8.2.0 @@ -42,6 +43,10 @@ Docs: https://tensorflownet.readthedocs.io + + + + diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index b6063234..56e69881 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Framework; +using static Tensorflow.Python; namespace Tensorflow.Train { @@ -10,9 +12,10 @@ namespace Tensorflow.Train /// public class AdamOptimizer : Optimizer { - private float _beta1; - private float _beta2; - private float _epsilon; + float _beta1; + float _beta2; + float _epsilon; + Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t; public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") : base(learning_rate, use_locking, name) @@ -21,5 +24,51 @@ namespace Tensorflow.Train _beta2 = beta2; _epsilon = epsilon; } + + public override Operation _apply_sparse(IndexedSlices grad, RefVariable var) + { + return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) => + { + return state_ops.scatter_add(x, i, v, use_locking: _use_locking); + }); + } + + private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func scatter_add) + { + var (beta1_power_v, beta2_power_v) = _get_beta_accumulators(); + Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype()); + Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype()); + var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype()); + var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype()); + var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype()); + var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype()); + var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); + var m = get_slot(var, "m"); + var m_scaled_g_values = grad * (1 - beta1_t); + var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking); + with(ops.control_dependencies(new[] { m_t }), delegate + { + m_t = scatter_add(m, indices, m_scaled_g_values); + }); + + var v = get_slot(var, "v"); + var v_scaled_g_values = (grad * grad) * (1 - beta2_t); + var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking); + with(ops.control_dependencies(new[] { v_t }), delegate + { + v_t = scatter_add(v, indices, v_scaled_g_values); + }); + var v_sqrt = math_ops.sqrt(v_t); + var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking); + return control_flow_ops.group(new[] { var_update, m_t, v_t }); + } + + private (RefVariable, RefVariable) _get_beta_accumulators() + { + ops.init_scope(); + var graph = ops.get_default_graph(); + return (_get_non_slot_variable("beta1_power", graph: graph), + _get_non_slot_variable("beta2_power", graph: graph)); + } } } diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 3a14390d..f5474c23 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Framework; using static Tensorflow.Python; namespace Tensorflow @@ -23,8 +24,8 @@ namespace Tensorflow public float LearningRate { get; set; } public Tensor LearningRateTensor { get; set; } public bool _use_locking; - public Dictionary _slots; - public Dictionary _non_slot_dict; + public Dictionary> _slots; + public Dictionary _non_slot_dict; public Dictionary _deferred_slot_restorations; public Optimizer(float learning_rate, bool use_locking, string name = null) @@ -36,8 +37,8 @@ namespace Tensorflow _use_locking = use_locking; LearningRate = learning_rate; // Dictionary of slots. - _slots = new Dictionary(); - _non_slot_dict = new Dictionary(); + _slots = new Dictionary>(); + _non_slot_dict = new Dictionary(); _deferred_slot_restorations = new Dictionary(); } @@ -110,7 +111,7 @@ namespace Tensorflow public Operation apply_gradients(Tuple[] grads_and_vars, RefVariable global_step = null, string name = null) { // No DistributionStrategy case. - var converted_grads_and_vars = new List>(); + var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>(); foreach (var (g, v) in grads_and_vars) { if(g != null) @@ -118,7 +119,7 @@ namespace Tensorflow // Convert the grad to Tensor or IndexedSlices if necessary. var gR = ops.convert_to_tensor_or_indexed_slices(g); var p = _get_processor(v); - converted_grads_and_vars.Add(new Tuple(gR, v, p)); + converted_grads_and_vars.Add((gR, v, p)); } } @@ -143,7 +144,8 @@ namespace Tensorflow var scope_name = var.op.name; with(ops.name_scope("update_" + scope_name), scope2 => { - update_ops.Add(processor.update_op(this, grad)); + var op = processor.update_op(this, grad); + update_ops.Add(op); }); } @@ -201,11 +203,69 @@ namespace Tensorflow return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; } + /// + /// Add ops to apply sparse gradients to `var`, with repeated sparse indices. + /// + /// + /// + /// + public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, RefVariable var) + { + var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); + var gradient_no_duplicate_indices = new IndexedSlices( + indices: unique_indices, + values: summed_values, + dense_shape: grad.dense_shape); + return _apply_sparse(gradient_no_duplicate_indices, var); + } + + public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) + { + throw new NotImplementedException("_apply_sparse"); + } + + public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) + { + var (unique_indices, new_index_positions) = array_ops.unique(indices); + var summed_values = math_ops.unsorted_segment_sum( + values, new_index_positions, + array_ops.shape(unique_indices)[0]); + return (summed_values, unique_indices); + } + public virtual void _prepare() { } + /// + /// Return a slot named `name` created for `var` by the Optimizer. + /// + /// + /// + /// + protected RefVariable get_slot(RefVariable var, string name) + { + var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; + if (named_slots == null) + return null; + + return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; + } + + private string _var_key(RefVariable var) + { + return $"{var.op.graph.graph_key}.{var.op.name}"; + } + + protected RefVariable _get_non_slot_variable(string name, Graph graph = null) + { + var key = $"{graph.graph_key}.{name}"; + var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; + + return non_slot; + } + private _OptimizableVariable _get_processor(RefVariable v) { if(v is RefVariable) diff --git a/src/TensorFlowNET.Core/Train/optimizer.py.cs b/src/TensorFlowNET.Core/Train/optimizer.py.cs index 3a376e97..fbf32876 100644 --- a/src/TensorFlowNET.Core/Train/optimizer.py.cs +++ b/src/TensorFlowNET.Core/Train/optimizer.py.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Framework; namespace Tensorflow { @@ -32,5 +33,12 @@ namespace Tensorflow return update_op; } + + public Operation update_op(Optimizer optimizer, IndexedSlices g) + { + var update_op = optimizer._apply_dense(g, _v); + + return update_op; + } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 4b4237a0..a5a4ab69 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -97,6 +97,20 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); return _op.outputs[0]; } - + + /// + /// Adds sparse updates to a variable reference. + /// + /// + /// + /// + /// + /// + /// + public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index aaa27e85..4022e1dc 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -72,5 +72,13 @@ namespace Tensorflow Tensor value, bool use_locking = false, string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + + public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name); + + throw new NotImplementedException("scatter_add"); + } } } diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index 60b6d050..465b08b2 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -8,10 +8,7 @@ using System.Text; using Newtonsoft.Json; using NumSharp; using Tensorflow; -using Tensorflow.Keras.Engine; using Tensorflow.Sessions; -using TensorFlowNET.Examples.Text.cnn_models; -using TensorFlowNET.Examples.TextClassification; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs index 4b4623dc..744e52c3 100644 --- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs +++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.nn_test public void testZeroFraction() { var x_shape = new Shape(5, 17); - var x_np = new NumPyRandom().randint(0, 2, x_shape); + var x_np = np.random.randint(0, 2, x_shape); x_np.astype(np.float32); var y_np = this._ZeroFraction(x_np);