diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 399578cf..4e5b0d89 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -40,6 +40,8 @@ namespace Tensorflow.Gradients
                 return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };
 
             var concat_dim = op.inputs[dim_index];
+            if (end_value_index == -1)
+                end_value_index = op.inputs.Length - 1;
             var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
 
             var out_grads = new List<Tensor>();
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index 0d1e6c8a..18151ac5 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -43,12 +43,6 @@ namespace Tensorflow
             if (grad_ys == null)
                 grad_ys = new Tensor[ys.Length];
 
-            var all = new List<Tensor>();
-            all.AddRange(ys);
-            all.AddRange(xs);
-            all.AddRange(stop_gradients);
-            all.AddRange(grad_ys);
-
             // Iterate over the collected ops.
             /**
              * grads: op => list of gradients received on each output endpoint of the
@@ -59,7 +53,8 @@ namespace Tensorflow
              **/
             var grads = new Dictionary<string, List<List<Tensor>>>();
 
-            with(ops.name_scope(name, "gradients", values: all), scope =>
+            with(ops.name_scope(name, "gradients",
+                values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
             {
                 string grad_scope = scope;
                 // Get a uid for this call to gradients that can be used to help
@@ -166,7 +161,7 @@ namespace Tensorflow
                 }
 
                 var inputs = _NonEagerInputs(op, xs).ToList();
-                foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads))
+                foreach (var (t_in, in_grad) in zip(inputs, in_grads))
                 {
                     if(in_grad != null)
                     {
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index fccc924b..5b33e655 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -72,7 +72,7 @@ namespace Tensorflow
         private string _graph_key;
         public string graph_key => _graph_key;
         public string _last_loss_reduction;
-
+        public bool _is_loss_scaled_by_optimizer { get; set; }
         public Status Status { get; }
 
         /// <summary>
diff --git a/src/TensorFlowNET.Core/Operations/InputList.cs b/src/TensorFlowNET.Core/Operations/InputList.cs
index 9abe7303..f442e546 100644
--- a/src/TensorFlowNET.Core/Operations/InputList.cs
+++ b/src/TensorFlowNET.Core/Operations/InputList.cs
@@ -10,7 +10,15 @@ namespace Tensorflow
     {
         public Tensor[] _inputs;
         public int Length => _inputs.Length;
-        public Tensor this[int index] => _inputs[index];
+        public Tensor this[int index]
+        {
+            get
+            {
+                if (index == -1)
+                    index = _inputs.Length - 1;
+                return _inputs[index];
+            }
+        }
 
         public InputList(Tensor[] inputs)
         {
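Note: the array_grad.cs and InputList.cs hunks above implement one convention: an index of -1 now refers to the last input, mirroring the Python-style negative indexing the ported TensorFlow gradient code relies on. A minimal standalone sketch of the normalization (the helper name is illustrative, not from the library):

    // Normalize a Python-style index: only -1 is mapped to the last element,
    // exactly as InputList's indexer and the end_value_index handling now do.
    static T At<T>(T[] items, int index)
    {
        if (index == -1)
            index = items.Length - 1;
        return items[index];
    }
    // At(new[] { 10, 20, 30 }, -1) == 30; other negative indices still throw.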
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index bafa254f..86c53286 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -56,6 +56,7 @@
 Removed global static graph instance.
+
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index f068f910..3a14390d 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -2,7 +2,6 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
-using distribute_lib = Tensorflow.Distribute;
 using static Tensorflow.Python;
 
 namespace Tensorflow
@@ -82,7 +81,8 @@ namespace Tensorflow
             var grads_and_vars = compute_gradients(loss, var_list:var_list,
                 gate_gradients: gate_gradients,
                 aggregation_method:aggregation_method,
-                colocate_gradients_with_ops: colocate_gradients_with_ops);
+                colocate_gradients_with_ops: colocate_gradients_with_ops,
+                grad_loss: grad_loss);
 
             var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
             if (vars_with_grad.Length == 0)
@@ -232,30 +232,31 @@ namespace Tensorflow
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
-            Tensor[] grad_loss = null)
+            Tensor grad_loss = null)
         {
+            // Scale loss if using a "mean" loss reduction and multiple replicas.
+            loss = _scale_loss(loss);
             int num_towers = 1;
-            if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
-            {
-
-            }
+
             var tmp = variables.trainable_variables();
+            var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
             switch (tmp)
             {
                 case List<RefVariable> values:
-                    var_list = values;
+                    var_list = values.Concat(vars).ToList();
                     break;
                 case List<VariableV1> values:
-                    var_list = values.Select(x => x as RefVariable).ToList();
+                    var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
                     break;
             }
+            var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
             var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
             var var_refs = processors.Select(x => x.target()).ToArray();
-            var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss,
-                gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
+            var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss },
+                gate_gradients: gate_gradients == GateGradientType.GATE_OP,
                 aggregation_method: aggregation_method,
                 colocate_gradients_with_ops: colocate_gradients_with_ops);
 
@@ -269,6 +270,14 @@ namespace Tensorflow
             return grads_and_vars;
         }
 
+        private Tensor _scale_loss(Tensor loss_value)
+        {
+            ops.get_default_graph()._is_loss_scaled_by_optimizer = false;
+            // TODO
+            // if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
+            return loss_value;
+        }
+
         protected T _call_if_callable<T>(T param)
         {
             return param;
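Note: grad_loss changes from Tensor[] to a single Tensor, and minimize() now actually forwards it; inside compute_gradients it becomes the grad_ys seed for backpropagation (when null, gradients() starts from ones_like(loss)). An illustrative call, assuming a scalar loss and a learning_rate already built on the default graph:

    // Passing a scalar grad_loss rescales every computed gradient by that factor.
    var opt = tf.train.AdamOptimizer(learning_rate);
    var grads_and_vars = opt.compute_gradients(loss, grad_loss: tf.constant(0.5f));

_scale_loss is deliberately a stub for now: it only clears _is_loss_scaled_by_optimizer on the default graph and returns the loss unchanged, leaving the distributed "mean" loss-reduction scaling as the commented TODO.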
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index bb3b7188..84f48db5 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -22,6 +22,16 @@ namespace Tensorflow
         /// </summary>
         public static string TRAINABLE_VARIABLES = "trainable_variables";
 
+        /// <summary>
+        /// Trainable resource-style variables.
+        /// </summary>
+        public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";
+
+        /// <summary>
+        /// Key for streaming model ports.
+        /// </summary>
+        public static string _STREAMING_MODEL_PORTS = "streaming_model_ports";
+
         /// <summary>
         /// Key to collect losses
         /// </summary>
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 8faf8841..20dbf668 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -45,6 +45,11 @@ namespace Tensorflow
             return get_default_graph().get_collection(key, scope);
         }
 
+        public static List<T> get_collection<T>(string key, string scope = null)
+        {
+            return get_default_graph().get_collection<T>(key, scope);
+        }
+
         public static object get_collection_ref(string key)
         {
             return get_default_graph().get_collection_ref(key);
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 0e7e29e0..b3ed9bb2 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -207,7 +207,7 @@ namespace TensorFlowNET.Examples
             Tensor predictions = null;
             with(tf.name_scope("output"), delegate
             {
-                logits = tf.layers.dense(h_pool_flat, keep_prob);
+                logits = tf.layers.dense(h_pool_flat, NUM_CLASS);
                 predictions = tf.argmax(logits, -1, output_type: tf.int32);
             });
 
@@ -215,7 +215,8 @@ namespace TensorFlowNET.Examples
             {
                 var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
                 var loss = tf.reduce_mean(sscel);
-                var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
+                var adam = tf.train.AdamOptimizer(learning_rate);
+                var optimizer = adam.minimize(loss, global_step: global_step);
             });
 
             with(tf.name_scope("accuracy"), delegate
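Note: the example change is a real bug fix, not a cleanup: keep_prob (the dropout keep probability) was being passed where the dense layer's unit count belongs, so the logits did not have one column per class. The shapes have to line up as sketched below (shape comments illustrative):

    // h_pool_flat : [batch_size, num_filters_total]
    // logits      : [batch_size, NUM_CLASS]  -- one column per class
    // y (labels)  : [batch_size], integer class ids in [0, NUM_CLASS)
    logits = tf.layers.dense(h_pool_flat, NUM_CLASS);
    var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
    var loss = tf.reduce_mean(sscel);

Splitting tf.train.AdamOptimizer(learning_rate).minimize(...) into two statements is behavior-neutral; it just exposes the optimizer instance for inspection while debugging.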