@@ -54,7 +54,7 @@ namespace Tensorflow
                 maximum_iterations: maximum_iterations,
                 return_same_structure: return_same_structure);
 
-        public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+        public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
             => ops.control_dependencies(control_inputs);
     }
 }
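
With the parameter widened to ITensorOrOperation[], tensors and operations can be mixed in a single dependency list. A minimal usage sketch (the variable names are illustrative, not taken from this change):

    // 'train_step' (an Operation) and 'ema_update' (a Tensor) are hypothetical nodes.
    var deps = new ITensorOrOperation[] { train_step, ema_update };
    tf_with(tf.control_dependencies(deps), delegate
    {
        // Ops created inside this scope run only after both dependencies execute.
        var step_done = tf.no_op(name: "step_done");
    });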
@@ -36,8 +36,8 @@ namespace Tensorflow
         public void device(string device_name)
             => get_default_graph().device(device_name);
 
-        public object get_collection(string key, string scope = "")
-            => get_default_graph().get_collection(key, scope: scope);
+        public List<T> get_collection<T>(string key, string scope = "")
+            => get_default_graph().get_collection<T>(key, scope: scope);
 
         /// <summary>
         /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
@@ -60,7 +60,7 @@ namespace Tensorflow
         /// </summary>
         /// <param name="name"></param>
         /// <returns></returns>
-        public Tensor no_op(string name = null)
+        public Operation no_op(string name = null)
             => gen_control_flow_ops.no_op(name: name);
 
         /// <summary>
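
Operation is the natural return type here: the underlying NoOp kernel has no output tensors, so the node is only useful as something to run or to depend on. A usage sketch (assuming a session sess is in scope):

    // Hold the op directly; there is no Tensor wrapper to unwrap before sess.run.
    Operation train_op = tf.no_op(name: "train_op");
    sess.run(train_op); // executes its control dependencies, returns no value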
@@ -180,7 +180,7 @@ namespace Tensorflow
             var graph = ops.get_default_graph();
             var var_list = new Dictionary<string, RefVariable>();
-            var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;
+            var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);
 
             if (variables != null)
             {
@@ -1,28 +1,30 @@
 /*****************************************************************************
    Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at
 
        http://www.apache.org/licenses/LICENSE-2.0
 
    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
 ******************************************************************************/
 
 using System;
+using System.Linq;
 using Tensorflow.Operations;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Gradients
 {
     /// <summary>
     /// Gradients for operators defined in control_flow_ops.py.cs
     /// </summary>
     [RegisterGradient("control_flow_grad")]
     public class control_flow_grad
     {
         /// <summary>
@@ -33,6 +35,7 @@ namespace Tensorflow.Gradients
         /// on the second visit. A next_iteration is also added on second visit.
         /// </summary>
         /// <returns></returns>
+        [RegisterGradient("Switch")]
         public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
         {
             throw new NotImplementedException("_SwitchGrad");
@@ -83,68 +86,68 @@ namespace Tensorflow.Gradients
             // false_grad = switch(grad[0], op.inputs[1])[0]
             // true_grad = switch(grad[1], op.inputs[1])[1]
             // return merge([false_grad, true_grad])[0], None
         }
 
         /// <summary>
         /// Gradients for a Merge op are calculated using a Switch op.
         /// </summary>
         [RegisterGradient("Merge")]
[RegisterGradient("Merge")] | |||
         public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
         {
             var grad = grads[0];
             var _ = grads[1];
             var input_op = op.inputs[0].op;
             var graph = ops.get_default_graph();
             var op_ctxt = control_flow_util.GetOutputContext(input_op);
             var grad_ctxt = graph._get_control_flow_context();
             switch (op_ctxt)
             {
                 case WhileContext cwhile:
                     {
                         return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot);
                     }
                 case CondContext ccond:
                     {
                         var pred = ccond.pred;
                         if (grad_ctxt != null && grad_ctxt.grad_state != null)
                         {
                             //# This Merge node is part of a cond within a loop.
                             //# The backprop needs to have the value of this predicate for every
                             //# iteration. So we must have its values accumulated in the forward, and
                             //# use the accumulated values as the predicate for this backprop switch.
                             var grad_state = grad_ctxt.grad_state;
                             var real_pred = grad_state.history_map[pred.name] as Tensor;
                             if (real_pred == null)
                             {
                                 //# Remember the value of pred for every iteration.
                                 grad_ctxt = grad_state.grad_context;
                                 grad_ctxt.Exit();
                                 var history_pred = grad_state.AddForwardAccumulator(pred);
                                 grad_ctxt.Enter();
 
                                 //# Add the stack pop op. If pred.op is in a (outer) CondContext,
                                 //# the stack pop will be guarded with a switch.
                                 real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred);
                                 grad_state.history_map[pred.name] = real_pred;
                             }
                             pred = real_pred;
                         }
                         var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
                         return results;
                     }
                 default:
                     {
                         var num_inputs = op.inputs.Length;
                         var cond = new Tensor[num_inputs];
                         for (int i = 0; i < num_inputs; i++)
                             cond[i] = math_ops.equal(op.outputs[1], i);
                         var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray();
                         return result;
                     }
             }
         }
 
         [RegisterGradient("RefMerge")]
         public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
         {
             return _MergeGrad(op, grads);
@@ -153,6 +156,7 @@ namespace Tensorflow.Gradients
         /// <summary>
         /// Gradients for an exit op are calculated using an Enter op.
         /// </summary>
+        [RegisterGradient("Exit")]
         public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
         {
             throw new NotImplementedException("_ExitGrad");
@@ -197,14 +201,16 @@ namespace Tensorflow.Gradients
         ///
         /// Note that the backprop next_iteration is added in switch grad.
         /// </summary>
-        public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad)
+        [RegisterGradient("NextIteration")]
+        public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
         {
-            return (_, grad);
+            return grad;
         }
 
-        public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad)
+        [RegisterGradient("RefNextIteration")]
+        public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
         {
-            return (_, grad);
+            return grad;
         }
         /// <summary>
@@ -213,7 +219,8 @@ namespace Tensorflow.Gradients
         /// For loop variables, grad is the gradient so just add an exit.
         /// For loop invariants, we need to add an accumulator loop.
         /// </summary>
-        public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad)
+        [RegisterGradient("Enter")]
+        public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
         {
             throw new NotImplementedException("_EnterGrad");
             // graph = ops.get_default_graph()
@@ -242,7 +249,9 @@ namespace Tensorflow.Gradients
             // return result
         }
 
-        public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad)
+        [RegisterGradient("RefEnter")]
+        public Tensor[] _RefEnterGrad(Tensor op, Tensor[] grad)
         {
             return _EnterGrad(op, grad);
         }
@@ -250,10 +259,11 @@ namespace Tensorflow.Gradients
         /// <summary>
         /// Stop backprop for the predicate of a while loop.
         /// </summary>
-        public object _LoopCondGrad(object _)
+        [RegisterGradient("LoopCond")]
+        public Tensor[] _LoopCondGrad(Tensor op, Tensor[] grad)
        {
             return null;
         }
     }
 }
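
These [RegisterGradient("...")] attributes are what lets the backprop pass resolve an op type such as "Merge" or "Exit" to its gradient method. As a rough illustration of the pattern only (not this library's actual lookup code, and assuming the attribute exposes its op-type string as a Name property), a reflection-based registry could be built like this:

    using System.Reflection;

    // Index gradient methods by op type, e.g. "Merge" -> _MergeGrad.
    var registry = new Dictionary<string, MethodInfo>();
    foreach (var method in typeof(control_flow_grad).GetMethods())
        foreach (RegisterGradient attr in method.GetCustomAttributes(typeof(RegisterGradient), false))
            registry[attr.Name] = method;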
@@ -108,7 +108,10 @@ namespace Tensorflow
             {
                 // generate gradient subgraph for op.
                 var op = queue.Dequeue();
+                if(tf.get_default_graph()._nodes_by_name.Count > 18505)
+                {
+                }
                 _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                 //if (loop_state != null)
                 //loop_state.EnterGradWhileContext(op, before: true);
@@ -157,8 +160,12 @@ namespace Tensorflow
                     // therefore dC/doutput[i] is 0.
                     foreach (var (i, out_grad) in enumerate(out_grads))
                     {
-                        if (out_grad == null)
+                        if (out_grad == null &&
+                            (grad_fn == null || _IsTrainable(op.outputs[i])))
                         {
+                            // Only trainable outputs or outputs for a function call that
+                            // will use SymbolicGradient get a zero gradient. Gradient
+                            // functions should ignore the gradient for other outputs.
                             if (loop_state != null)
                                 ;
                             else
@@ -170,7 +177,15 @@ namespace Tensorflow
                     {
                         if (grad_fn != null)
                         {
-                            in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn);
+                            in_grads = _MaybeCompile(grad_scope,
+                                op,
+                                out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                null,
+                                grad_fn);
                         }
                         else
                         {
                             throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
                         }
                         _VerifyGeneratedGradients(in_grads, op);
                         if (gate_gradients && in_grads.Count(x => x != null) > 1)
@@ -227,6 +227,10 @@ namespace Tensorflow
         public void add_to_collection<T>(string name, T value)
         {
+            if(name == "update_ops")
+            {
+            }
             _check_not_finalized();
             if (_collections.ContainsKey(name))
                 (_collections[name] as List<T>).Add(value);
@@ -442,17 +446,20 @@ namespace Tensorflow
                 case List<Tensor> list:
                     t = list.Select(x => (T)(object)x).ToList();
                     break;
+                case List<Operation> list:
+                    t = list.Select(x => (T)(object)x).ToList();
+                    break;
                 default:
                     throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
             }
 
             return t;
         }
 
-        public object get_collection_ref(string name)
+        public List<T> get_collection_ref<T>(string name)
         {
             if (!_collections.ContainsKey(name))
-                _collections[name] = new List<object>();
-            return _collections[name];
+                _collections[name] = new List<T>();
+            return _collections[name] as List<T>;
         }
 
         public void prevent_feeding(Tensor tensor)
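
The generic overloads replace cast-at-call-site patterns such as get_collection(key) as List&lt;T&gt;. One caveat worth noting: get_collection_ref&lt;T&gt; casts the stored list with as, so if a collection was first created with a different element type, the call returns null rather than throwing. A usage sketch (my_train_op is a placeholder name):

    // Typed read; no cast needed at the call site.
    var update_ops = tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS);

    // Typed live reference; additions are visible to later readers of the collection.
    var train_ops = ops.get_collection_ref<Operation>(tf.GraphKeys.TRAIN_OP);
    train_ops.Add(my_train_op);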
@@ -90,7 +90,7 @@ namespace Tensorflow.Layers
         {
             foreach(var name in collection_list)
             {
-                var collection = ops.get_collection_ref(name) as List<object>;
+                var collection = ops.get_collection_ref<Operation>(name);
                 foreach (var element in elements)
                     if (!collection.Contains(element))
@@ -54,6 +54,10 @@ namespace Tensorflow
         public void _set_control_flow_context(ControlFlowContext ctx)
         {
+            if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1")
+            {
+            }
             _control_flow_context = ctx;
         }
@@ -151,6 +151,11 @@ namespace Tensorflow
                 }
             }
 
+            if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1")
+            {
+            }
             // Dict mapping op name to file and line information for op colocation
             // context managers.
             _control_flow_context = graph._get_control_flow_context();
@@ -154,10 +154,10 @@ namespace Tensorflow.Train
             var beta2 = _call_if_callable(_beta2);
             var epsilon = _call_if_callable(_epsilon);
 
-            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
-            _beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
-            _beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
-            _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
+            _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");
+            _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1");
+            _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2");
+            _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon");
         }
     }
 }
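
The null-coalescing guards make _prepare idempotent: the hyperparameter tensors are created on the first call and reused afterwards, instead of duplicate constant nodes being added to the graph each time the optimizer is driven. A sketch of the effect (loss_a and loss_b are illustrative):

    var adam = tf.train.AdamOptimizer(0.001f);
    var op_a = adam.minimize(loss_a); // _prepare creates learning_rate/beta1/beta2/epsilon tensors
    var op_b = adam.minimize(loss_b); // _prepare now reuses them via '??'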
@@ -212,7 +212,7 @@ namespace Tensorflow
             if (!tf.context.executing_eagerly())
             {
-                var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
+                var train_op = ops.get_collection_ref<Operation>(tf.GraphKeys.TRAIN_OP);
                 if (train_op != null && train_op.Contains(apply_updates))
                     train_op.Add(apply_updates);
             }
@@ -373,17 +373,19 @@ namespace Tensorflow
             loss = _scale_loss(loss);
             int num_towers = 1;
 
-            var tmp = variables.trainable_variables();
-            var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
-            switch (tmp)
+            if(var_list == null)
             {
-                case List<RefVariable> values:
-                    var_list = values.Concat(vars).ToList();
-                    break;
-                case List<VariableV1> values:
-                    var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                    break;
+                var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var tmp = variables.trainable_variables();
+                switch (tmp)
+                {
+                    case List<RefVariable> values:
+                        var_list = values.Concat(vars).ToList();
+                        break;
+                    case List<VariableV1> values:
+                        var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
+                        break;
+                }
             }
 
             var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
@@ -133,7 +133,7 @@ namespace Tensorflow
             var check_collection_list = graph.get_all_collection_keys();
             foreach (var collection_type in check_collection_list)
             {
-                var cols = graph.get_collection(collection_type);
+                /*var cols = graph.get_collection(collection_type);
                 switch (cols)
                 {
                     case List<Tensor> values:
@@ -165,7 +165,7 @@ namespace Tensorflow
                         break;
                     default:
                         throw new NotImplementedException("_build_internal.check_collection_list");
-                }
+                }*/
             }
@@ -73,9 +73,9 @@ namespace Tensorflow
             return get_default_graph().get_collection<T>(key, scope);
         }
 
-        public static object get_collection_ref(string key)
+        public static List<T> get_collection_ref<T>(string key)
         {
-            return get_default_graph().get_collection_ref(key);
+            return get_default_graph().get_collection_ref<T>(key);
         }
 
         /// <summary>
@@ -52,6 +52,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
         Tensor learn_rate;
         Tensor loss;
         List<RefVariable> first_stage_trainable_var_list;
+        Operation train_op_with_frozen_variables;
+        Operation train_op_with_all_variables;
         #endregion
 
         public bool Run()
@@ -153,6 +155,33 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
                 var adam = tf.train.AdamOptimizer(learn_rate);
                 var first_stage_optimizer = adam.minimize(loss, var_list: first_stage_trainable_var_list);
+                tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
+                {
+                    tf_with(tf.control_dependencies(new ITensorOrOperation[] { first_stage_optimizer, global_step_update }), delegate
+                    {
+                        tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
+                        {
+                            train_op_with_frozen_variables = tf.no_op();
+                        });
+                    });
+                });
             });
 
+            tf_with(tf.name_scope("define_second_stage_train"), delegate
+            {
+                var second_stage_trainable_var_list = tf.trainable_variables().Select(x => x as RefVariable).ToList();
+                var adam = tf.train.AdamOptimizer(learn_rate);
+                var second_stage_optimizer = adam.minimize(loss, var_list: second_stage_trainable_var_list);
+                tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
+                {
+                    tf_with(tf.control_dependencies(new ITensorOrOperation[] { second_stage_optimizer, global_step_update }), delegate
+                    {
+                        tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
+                        {
+                            train_op_with_all_variables = tf.no_op();
+                        });
+                    });
+                });
+            });
 
             return graph;
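
Because control_dependencies contexts accumulate, the tf.no_op() created at the innermost level depends on the UPDATE_OPS collection, the optimizer step, global_step_update, and moving_ave all at once, so running that single Operation drives a full training step. A flat sketch of the same grouping (identifiers as in the hunk above, assuming no ordering is needed between the groups):

    var deps = tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS)
        .Cast<ITensorOrOperation>()
        .Concat(new ITensorOrOperation[] { first_stage_optimizer, global_step_update, moving_ave })
        .ToArray();
    tf_with(tf.control_dependencies(deps), delegate
    {
        train_op_with_frozen_variables = tf.no_op();
    });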
@@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text
             var y_one_hot = tf.one_hot(y, num_class);
             loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));
 
-            var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>;
+            var update_ops = tf.get_collection<object>(tf.GraphKeys.UPDATE_OPS);
             tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate
             {
                 var adam = tf.train.AdamOptimizer(learning_rate);
@@ -422,7 +422,7 @@ namespace TensorFlowNET.UnitTest
             new_saver.restore(sess, dir + "my-model-10000");
             var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
             var batch_size = tf.size(labels);
-            var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0] as Tensor;
+            var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor;
             var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
                 logits: logits);
         }
@@ -1495,6 +1495,7 @@ namespace TensorFlowNET.UnitTest
             #endregion
         }
 
+        [Ignore("Not finished yet")]
         [TestMethod]
         public void map_fn()
         {