diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs index e292c743..6ed475a9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -54,7 +54,7 @@ namespace Tensorflow maximum_iterations: maximum_iterations, return_same_structure: return_same_structure); - public _ControlDependenciesController control_dependencies(Operation[] control_inputs) + public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) => ops.control_dependencies(control_inputs); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index c91f283d..571d57b2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -36,8 +36,8 @@ namespace Tensorflow public void device(string device_name) => get_default_graph().device(device_name); - public object get_collection(string key, string scope = "") - => get_default_graph().get_collection(key, scope: scope); + public List get_collection(string key, string scope = "") + => get_default_graph().get_collection(key, scope: scope); /// /// A context manager that lifts ops out of control-flow scopes and function-building graphs. @@ -60,7 +60,7 @@ namespace Tensorflow /// /// /// - public Tensor no_op(string name = null) + public Operation no_op(string name = null) => gen_control_flow_ops.no_op(name: name); /// diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index d7d7ef7e..aaa5d225 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -180,7 +180,7 @@ namespace Tensorflow var graph = ops.get_default_graph(); var var_list = new Dictionary(); - var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List; + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES); if (variables != null) { diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs similarity index 93% rename from src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs rename to src/TensorFlowNET.Core/Gradients/control_flow_grad.cs index 22a73374..670731e0 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -1,28 +1,30 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. ******************************************************************************/ using System; using System.Linq; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow.Gradients { /// /// Gradients for operators defined in control_flow_ops.py.cs /// + [RegisterGradient("control_flow_grad")] public class control_flow_grad { /// @@ -33,6 +35,7 @@ namespace Tensorflow.Gradients /// on the second visit. A next_iteration is also added on second visit. /// /// + [RegisterGradient("Switch")] public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) { throw new NotImplementedException("_SwitchGrad"); @@ -83,68 +86,68 @@ namespace Tensorflow.Gradients // false_grad = switch(grad[0], op.inputs[1])[0] // true_grad = switch(grad[1], op.inputs[1])[1] // return merge([false_grad, true_grad])[0], None - } - + } + /// /// Gradients for a Merge op are calculated using a Switch op. /// - [RegisterGradient("Merge")] + [RegisterGradient("Merge")] public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) { var grad = grads[0]; - var _ = grads[1]; var input_op = op.inputs[0].op; var graph = ops.get_default_graph(); var op_ctxt = control_flow_util.GetOutputContext(input_op); var grad_ctxt = graph._get_control_flow_context(); switch (op_ctxt) { - case WhileContext cwhile: - { - return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); + case WhileContext cwhile: + { + return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); } - case CondContext ccond: - { - var pred = ccond.pred; - if (grad_ctxt != null && grad_ctxt.grad_state != null) - { - //# This Merge node is part of a cond within a loop. - //# The backprop needs to have the value of this predicate for every - //# iteration. So we must have its values accumulated in the forward, and - //# use the accumulated values as the predicate for this backprop switch. - var grad_state = grad_ctxt.grad_state; - var real_pred = grad_state.history_map[pred.name] as Tensor; - if (real_pred == null) - { - //# Remember the value of pred for every iteration. - grad_ctxt = grad_state.grad_context; - grad_ctxt.Exit(); - var history_pred = grad_state.AddForwardAccumulator(pred); - grad_ctxt.Enter(); - - //# Add the stack pop op. If pred.op is in a (outer) CondContext, - //# the stack pop will be guarded with a switch. - real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); - grad_state.history_map[pred.name] = real_pred; - } - pred = real_pred; - } - var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); - return results; + case CondContext ccond: + { + var pred = ccond.pred; + if (grad_ctxt != null && grad_ctxt.grad_state != null) + { + //# This Merge node is part of a cond within a loop. + //# The backprop needs to have the value of this predicate for every + //# iteration. So we must have its values accumulated in the forward, and + //# use the accumulated values as the predicate for this backprop switch. + var grad_state = grad_ctxt.grad_state; + var real_pred = grad_state.history_map[pred.name] as Tensor; + if (real_pred == null) + { + //# Remember the value of pred for every iteration. + grad_ctxt = grad_state.grad_context; + grad_ctxt.Exit(); + var history_pred = grad_state.AddForwardAccumulator(pred); + grad_ctxt.Enter(); + + //# Add the stack pop op. If pred.op is in a (outer) CondContext, + //# the stack pop will be guarded with a switch. + real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); + grad_state.history_map[pred.name] = real_pred; + } + pred = real_pred; + } + var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); + return results; } - default: - { - var num_inputs = op.inputs.Length; - var cond = new Tensor[num_inputs]; - for (int i = 0; i < num_inputs; i++) - cond[i] = math_ops.equal(op.outputs[1], i); - var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); - return result; + default: + { + var num_inputs = op.inputs.Length; + var cond = new Tensor[num_inputs]; + for (int i = 0; i < num_inputs; i++) + cond[i] = math_ops.equal(op.outputs[1], i); + var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); + return result; } } } + [RegisterGradient("RefMerge")] public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) { return _MergeGrad(op, grads); @@ -153,6 +156,7 @@ namespace Tensorflow.Gradients /// /// Gradients for an exit op are calculated using an Enter op. /// + [RegisterGradient("Exit")] public Tensor[] _ExitGrad(Operation op, Tensor[] grads) { throw new NotImplementedException("_ExitGrad"); @@ -197,14 +201,16 @@ namespace Tensorflow.Gradients /// /// Note that the backprop next_iteration is added in switch grad. /// - public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) + [RegisterGradient("NextIteration")] + public Tensor[] _NextIterationGrad(object _, Tensor[] grad) { - return (_, grad); + return grad; } - public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) + [RegisterGradient("RefNextIteration")] + public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad) { - return (_, grad); + return grad; } /// @@ -213,7 +219,8 @@ namespace Tensorflow.Gradients /// For loop variables, grad is the gradient so just add an exit. /// For loop invariants, we need to add an accumulator loop. /// - public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad) + [RegisterGradient("Enter")] + public Tensor[] _EnterGrad(Tensor op, Tensor[] grad) { throw new NotImplementedException("_EnterGrad"); // graph = ops.get_default_graph() @@ -242,7 +249,9 @@ namespace Tensorflow.Gradients // return result } - public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad) + + [RegisterGradient("RefEnter")] + public Tensor[] _RefEnterGrad(Tensor op, Tensor[] grad) { return _EnterGrad(op, grad); } @@ -250,10 +259,11 @@ namespace Tensorflow.Gradients /// /// Stop backprop for the predicate of a while loop. /// - public object _LoopCondGrad(object _) + [RegisterGradient("LoopCond")] + public Tensor[] _LoopCondGrad(Tensor op, Tensor[] grad) { return null; - } - - } + } + + } } diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index bfa1d296..7252301a 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -108,7 +108,10 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); + if(tf.get_default_graph()._nodes_by_name.Count > 18505) + { + } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); //if (loop_state != null) //loop_state.EnterGradWhileContext(op, before: true); @@ -157,8 +160,12 @@ namespace Tensorflow // therefore dC/doutput[i] is 0. foreach (var (i, out_grad) in enumerate(out_grads)) { - if (out_grad == null) + if (out_grad == null && + (grad_fn == null || _IsTrainable(op.outputs[i]))) { + // Only trainable outputs or outputs for a function call that + // will use SymbolicGradient get a zero gradient. Gradient + // functions should ignore the gradient for other outputs. if (loop_state != null) ; else @@ -170,7 +177,15 @@ namespace Tensorflow { if (grad_fn != null) { - in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn); + in_grads = _MaybeCompile(grad_scope, + op, + out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, + grad_fn); + } + else + { + throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); } _VerifyGeneratedGradients(in_grads, op); if (gate_gradients && in_grads.Count(x => x != null) > 1) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 48cec7a9..8b7cf7a0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -227,6 +227,10 @@ namespace Tensorflow public void add_to_collection(string name, T value) { + if(name == "update_ops") + { + + } _check_not_finalized(); if (_collections.ContainsKey(name)) (_collections[name] as List).Add(value); @@ -442,17 +446,20 @@ namespace Tensorflow case List list: t = list.Select(x => (T)(object)x).ToList(); break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; default: throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); } return t; } - public object get_collection_ref(string name) + public List get_collection_ref(string name) { if (!_collections.ContainsKey(name)) - _collections[name] = new List(); - return _collections[name]; + _collections[name] = new List(); + return _collections[name] as List; } public void prevent_feeding(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index a3ae3356..444c2dd4 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -90,7 +90,7 @@ namespace Tensorflow.Layers { foreach(var name in collection_list) { - var collection = ops.get_collection_ref(name) as List; + var collection = ops.get_collection_ref(name); foreach (var element in elements) if (!collection.Contains(element)) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 2f61f954..8e317df9 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -54,6 +54,10 @@ namespace Tensorflow public void _set_control_flow_context(ControlFlowContext ctx) { + if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1") + { + + } _control_flow_context = ctx; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 00ba8c78..b6811917 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -151,6 +151,11 @@ namespace Tensorflow } } + if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1") + { + + } + // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index ed5411cd..faf6fec2 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -154,10 +154,10 @@ namespace Tensorflow.Train var beta2 = _call_if_callable(_beta2); var epsilon = _call_if_callable(_epsilon); - _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); - _beta1_t = ops.convert_to_tensor(beta1, name: "beta1"); - _beta2_t = ops.convert_to_tensor(beta2, name: "beta2"); - _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon"); + _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate"); + _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1"); + _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2"); + _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon"); } } } diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 0d6c304a..e0040ecf 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -212,7 +212,7 @@ namespace Tensorflow if (!tf.context.executing_eagerly()) { - var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List; + var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP); if (train_op != null && train_op.Contains(apply_updates)) train_op.Add(apply_updates); } @@ -373,17 +373,19 @@ namespace Tensorflow loss = _scale_loss(loss); int num_towers = 1; - - var tmp = variables.trainable_variables(); - var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); - switch (tmp) + if(var_list == null) { - case List values: - var_list = values.Concat(vars).ToList(); - break; - case List values: - var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); - break; + var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); + var tmp = variables.trainable_variables(); + switch (tmp) + { + case List values: + var_list = values.Concat(vars).ToList(); + break; + case List values: + var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); + break; + } } var_list = var_list.Concat(ops.get_collection(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index c9b60d32..7fe1a891 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -133,7 +133,7 @@ namespace Tensorflow var check_collection_list = graph.get_all_collection_keys(); foreach (var collection_type in check_collection_list) { - var cols = graph.get_collection(collection_type); + /*var cols = graph.get_collection(collection_type); switch (cols) { case List values: @@ -165,7 +165,7 @@ namespace Tensorflow break; default: throw new NotImplementedException("_build_internal.check_collection_list"); - } + }*/ } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 9aa334db..6ab9feae 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -73,9 +73,9 @@ namespace Tensorflow return get_default_graph().get_collection(key, scope); } - public static object get_collection_ref(string key) + public static List get_collection_ref(string key) { - return get_default_graph().get_collection_ref(key); + return get_default_graph().get_collection_ref(key); } /// diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index cad2fbb5..73f5c213 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -52,6 +52,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO Tensor learn_rate; Tensor loss; List first_stage_trainable_var_list; + Operation train_op_with_frozen_variables; + Operation train_op_with_all_variables; #endregion public bool Run() @@ -153,6 +155,33 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO var adam = tf.train.AdamOptimizer(learn_rate); var first_stage_optimizer = adam.minimize(loss, var_list: first_stage_trainable_var_list); + tf_with(tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate + { + tf_with(tf.control_dependencies(new ITensorOrOperation[] { first_stage_optimizer, global_step_update }), delegate + { + tf_with(tf.control_dependencies(new[] { moving_ave }), delegate + { + train_op_with_frozen_variables = tf.no_op(); + }); + }); + }); + }); + + tf_with(tf.name_scope("define_second_stage_train"), delegate + { + var second_stage_trainable_var_list = tf.trainable_variables().Select(x => x as RefVariable).ToList(); + var adam = tf.train.AdamOptimizer(learn_rate); + var second_stage_optimizer = adam.minimize(loss, var_list: second_stage_trainable_var_list); + tf_with(tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate + { + tf_with(tf.control_dependencies(new ITensorOrOperation[] { second_stage_optimizer, global_step_update }), delegate + { + tf_with(tf.control_dependencies(new[] { moving_ave }), delegate + { + train_op_with_all_variables = tf.no_op(); + }); + }); + }); }); return graph; diff --git a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs index 6150fa90..9a2351b2 100644 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text var y_one_hot = tf.one_hot(y, num_class); loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); - var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List; + var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS); tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate { var adam = tf.train.AdamOptimizer(learning_rate); diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index c3ba6277..6a117ac1 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -422,7 +422,7 @@ namespace TensorFlowNET.UnitTest new_saver.restore(sess, dir + "my-model-10000"); var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); var batch_size = tf.size(labels); - var logits = (tf.get_collection("logits") as List)[0] as Tensor; + var logits = tf.get_collection("logits")[0] as Tensor; var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, logits: logits); } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 2086bb36..b5d37d35 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1495,6 +1495,7 @@ namespace TensorFlowNET.UnitTest #endregion } + [Ignore("Not finished yet")] [TestMethod] public void map_fn() {