@@ -54,7 +54,7 @@ namespace Tensorflow
                 maximum_iterations: maximum_iterations,
                 return_same_structure: return_same_structure);

-        public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+        public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
             => ops.control_dependencies(control_inputs);
     }
 }
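
Note: widening the parameter from Operation[] to ITensorOrOperation[] lets tensors and operations share one dependency list. A minimal usage sketch (tf_with and tf.no_op as used elsewhere in this diff; assign_op and summary_t are illustrative names, not from the source):

    // assign_op is an Operation, summary_t a Tensor; both implement ITensorOrOperation
    tf_with(tf.control_dependencies(new ITensorOrOperation[] { assign_op, summary_t }), delegate
    {
        // nodes created here run only after both dependencies have executed
        var train_step = tf.no_op(name: "train_step");
    });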
@@ -36,8 +36,8 @@ namespace Tensorflow
         public void device(string device_name)
             => get_default_graph().device(device_name);

-        public object get_collection(string key, string scope = "")
-            => get_default_graph().get_collection(key, scope: scope);
+        public List<T> get_collection<T>(string key, string scope = "")
+            => get_default_graph().get_collection<T>(key, scope: scope);

         /// <summary>
         /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
@@ -60,7 +60,7 @@ namespace Tensorflow
         /// </summary>
         /// <param name="name"></param>
         /// <returns></returns>
-        public Tensor no_op(string name = null)
+        public Operation no_op(string name = null)
             => gen_control_flow_ops.no_op(name: name);

         /// <summary>
@@ -180,7 +180,7 @@ namespace Tensorflow
             var graph = ops.get_default_graph();
             var var_list = new Dictionary<string, RefVariable>();
-            var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;
+            var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);

             if (variables != null)
             {
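
Note: the generic get_collection<T> replaces the object-returning overload plus `as` cast, which silently produced null whenever the stored list had a different element type. A usage sketch under that assumption:

    // typed access to a graph collection; no cast, no silent null
    var globals = tf.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);
    foreach (var v in globals)
        Console.WriteLine(v.name);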
@@ -1,28 +1,30 @@
 /*****************************************************************************
    Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
 ******************************************************************************/

 using System;
 using System.Linq;
 using Tensorflow.Operations;
+using static Tensorflow.Binding;

 namespace Tensorflow.Gradients
 {
     /// <summary>
     /// Gradients for operators defined in control_flow_ops.py.cs
     /// </summary>
+    [RegisterGradient("control_flow_grad")]
     public class control_flow_grad
     {
         /// <summary>
@@ -33,6 +35,7 @@ namespace Tensorflow.Gradients
         /// on the second visit. A next_iteration is also added on second visit.
         /// </summary>
         /// <returns></returns>
+        [RegisterGradient("Switch")]
         public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
         {
             throw new NotImplementedException("_SwitchGrad");
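
Note: with per-method [RegisterGradient("...")] attributes, each gradient function is registered for one op type instead of the whole class. A hedged sketch of how an attribute-driven registry could discover these methods via reflection; the real lookup in TensorFlow.NET may differ, and the attribute's Name property is an assumption:

    // hypothetical: locate the gradient method registered for "Switch"
    var gradMethod = typeof(control_flow_grad).GetMethods()
        .FirstOrDefault(m => m.GetCustomAttributes(typeof(RegisterGradient), false)
            .Cast<RegisterGradient>()
            .Any(a => a.Name == "Switch")); // assumes the attribute exposes Name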
@@ -83,68 +86,68 @@ namespace Tensorflow.Gradients
             // false_grad = switch(grad[0], op.inputs[1])[0]
             // true_grad = switch(grad[1], op.inputs[1])[1]
             // return merge([false_grad, true_grad])[0], None
         }
         /// <summary>
         /// Gradients for a Merge op are calculated using a Switch op.
         /// </summary>
         [RegisterGradient("Merge")]
         public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
         {
             var grad = grads[0];
-            var _ = grads[1];
             var input_op = op.inputs[0].op;
             var graph = ops.get_default_graph();
             var op_ctxt = control_flow_util.GetOutputContext(input_op);
             var grad_ctxt = graph._get_control_flow_context();
             switch (op_ctxt)
             {
                 case WhileContext cwhile:
                     {
                         return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot);
                     }
                 case CondContext ccond:
                     {
                         var pred = ccond.pred;
                         if (grad_ctxt != null && grad_ctxt.grad_state != null)
                         {
                             // This Merge node is part of a cond within a loop.
                             // The backprop needs to have the value of this predicate for every
                             // iteration. So we must have its values accumulated in the forward, and
                             // use the accumulated values as the predicate for this backprop switch.
                             var grad_state = grad_ctxt.grad_state;
                             var real_pred = grad_state.history_map[pred.name] as Tensor;
                             if (real_pred == null)
                             {
                                 // Remember the value of pred for every iteration.
                                 grad_ctxt = grad_state.grad_context;
                                 grad_ctxt.Exit();
                                 var history_pred = grad_state.AddForwardAccumulator(pred);
                                 grad_ctxt.Enter();

                                 // Add the stack pop op. If pred.op is in a (outer) CondContext,
                                 // the stack pop will be guarded with a switch.
                                 real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred);
                                 grad_state.history_map[pred.name] = real_pred;
                             }
                             pred = real_pred;
                         }
                         var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
                         return results;
                     }
                 default:
                     {
                         var num_inputs = op.inputs.Length;
                         var cond = new Tensor[num_inputs];
                         for (int i = 0; i < num_inputs; i++)
                             cond[i] = math_ops.equal(op.outputs[1], i);
                         var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray();
                         return result;
                     }
             }
         }
[RegisterGradient("RefMerge")] | |||||
public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
return _MergeGrad(op, grads); | return _MergeGrad(op, grads); | ||||
@@ -153,6 +156,7 @@ namespace Tensorflow.Gradients
         /// <summary>
         /// Gradients for an exit op are calculated using an Enter op.
         /// </summary>
+        [RegisterGradient("Exit")]
         public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
         {
             throw new NotImplementedException("_ExitGrad");
@@ -197,14 +201,16 @@ namespace Tensorflow
         ///
         /// Note that the backprop next_iteration is added in switch grad.
         /// </summary>
-        public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad)
+        [RegisterGradient("NextIteration")]
+        public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
         {
-            return (_, grad);
+            return grad;
         }

-        public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad)
+        [RegisterGradient("RefNextIteration")]
+        public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
         {
-            return (_, grad);
+            return grad;
         }

         /// <summary>
@@ -213,7 +219,8 @@ namespace Tensorflow
         /// For loop variables, grad is the gradient so just add an exit.
         /// For loop invariants, we need to add an accumulator loop.
         /// </summary>
-        public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad)
+        [RegisterGradient("Enter")]
+        public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
         {
             throw new NotImplementedException("_EnterGrad");
             // graph = ops.get_default_graph()
@@ -242,7 +249,9 @@ namespace Tensorflow
             // return result
         }

-        public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad)
+        [RegisterGradient("RefEnter")]
+        public Tensor[] _RefEnterGrad(Tensor op, Tensor[] grad)
         {
             return _EnterGrad(op, grad);
         }
@@ -250,10 +259,11 @@ namespace Tensorflow
         /// <summary>
         /// Stop backprop for the predicate of a while loop.
         /// </summary>
-        public object _LoopCondGrad(object _)
+        [RegisterGradient("LoopCond")]
+        public Tensor[] _LoopCondGrad(Tensor op, Tensor[] grad)
         {
             return null;
         }
     }
 }
@@ -108,7 +108,10 @@ namespace Tensorflow
             {
                 // generate gradient subgraph for op.
                 var op = queue.Dequeue();
+                if(tf.get_default_graph()._nodes_by_name.Count > 18505)
+                {
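+                    // empty body: conditional-breakpoint anchor left in while debugging large graphs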
+                }
                 _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                 //if (loop_state != null)
                 //loop_state.EnterGradWhileContext(op, before: true);
@@ -157,8 +160,12 @@ namespace Tensorflow
                     // therefore dC/doutput[i] is 0.
                     foreach (var (i, out_grad) in enumerate(out_grads))
                     {
-                        if (out_grad == null)
+                        if (out_grad == null &&
+                            (grad_fn == null || _IsTrainable(op.outputs[i])))
                         {
+                            // Only trainable outputs or outputs for a function call that
+                            // will use SymbolicGradient get a zero gradient. Gradient
+                            // functions should ignore the gradient for other outputs.
                             if (loop_state != null)
                                 ;
                             else
@@ -170,7 +177,15 @@ namespace Tensorflow
                     {
                         if (grad_fn != null)
                         {
-                            in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn);
+                            in_grads = _MaybeCompile(grad_scope,
+                                op,
+                                out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+                                null,
+                                grad_fn);
+                        }
+                        else
+                        {
+                            throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
                         }

                         _VerifyGeneratedGradients(in_grads, op);
                         if (gate_gradients && in_grads.Count(x => x != null) > 1)
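
Note: the added Where filter matters because out_grads can hold null entries for outputs that received no gradient; indexing them with x[0] would throw. The pattern in isolation (toy data, plain C#):

    // drop nulls before projecting the first element of each entry
    var outGrads = new[] { new[] { 1 }, null, new[] { 3 } };
    var firsts = outGrads.Where(x => x != null).Select(x => x[0]).ToArray(); // { 1, 3 }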
@@ -227,6 +227,10 @@ namespace Tensorflow
         public void add_to_collection<T>(string name, T value)
         {
+            if(name == "update_ops")
+            {
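+                // empty body: conditional-breakpoint anchor for the "update_ops" collection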
+            }
             _check_not_finalized();
             if (_collections.ContainsKey(name))
                 (_collections[name] as List<T>).Add(value);
@@ -442,17 +446,20 @@ namespace Tensorflow
                 case List<Tensor> list:
                     t = list.Select(x => (T)(object)x).ToList();
                     break;
+                case List<Operation> list:
+                    t = list.Select(x => (T)(object)x).ToList();
+                    break;
                 default:
                     throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
             }

             return t;
         }

-        public object get_collection_ref(string name)
+        public List<T> get_collection_ref<T>(string name)
         {
             if (!_collections.ContainsKey(name))
-                _collections[name] = new List<object>();
-            return _collections[name];
+                _collections[name] = new List<T>();
+            return _collections[name] as List<T>;
         }

         public void prevent_feeding(Tensor tensor)
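
Note: get_collection_ref<T> hands back the stored list itself (creating it on first use), so callers mutate the live collection rather than a copy; the `as List<T>` cast still yields null if the collection was first created with a different element type. A usage sketch mirroring the optimizer hunk below:

    // the returned list is the live backing store, not a snapshot
    var train_ops = graph.get_collection_ref<Operation>(tf.GraphKeys.TRAIN_OP);
    train_ops.Add(apply_updates); // visible to every later get_collection call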
@@ -90,7 +90,7 @@ namespace Tensorflow.Layers
             {
                 foreach(var name in collection_list)
                 {
-                    var collection = ops.get_collection_ref(name) as List<object>;
+                    var collection = ops.get_collection_ref<Operation>(name);

                     foreach (var element in elements)
                         if (!collection.Contains(element))
@@ -54,6 +54,10 @@ namespace Tensorflow
         public void _set_control_flow_context(ControlFlowContext ctx)
         {
+            if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1")
+            {
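+                // empty body: breakpoint anchor for one specific op while debugging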
+            }
             _control_flow_context = ctx;
         }
@@ -151,6 +151,11 @@ namespace Tensorflow
                 }
             }

+            if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1")
+            {
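+                // empty body: breakpoint anchor for one specific node while debugging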
+            }
+
             // Dict mapping op name to file and line information for op colocation
             // context managers.
             _control_flow_context = graph._get_control_flow_context();
@@ -154,10 +154,10 @@ namespace Tensorflow.Train
             var beta2 = _call_if_callable(_beta2);
             var epsilon = _call_if_callable(_epsilon);

-            _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
-            _beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
-            _beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
-            _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
+            _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");
+            _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1");
+            _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2");
+            _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon");
         }
     }
 }
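
Note: the ?? operators make _prepare idempotent: convert_to_tensor runs only on the first call, so invoking minimize repeatedly on one optimizer reuses the cached tensors instead of creating fresh graph constants. The shape of the fix in isolation:

    // first call converts; later calls keep the existing tensor
    _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");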
@@ -212,7 +212,7 @@ namespace Tensorflow
             if (!tf.context.executing_eagerly())
             {
-                var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
+                var train_op = ops.get_collection_ref<Operation>(tf.GraphKeys.TRAIN_OP);
                 if (train_op != null && train_op.Contains(apply_updates))
                     train_op.Add(apply_updates);
             }
@@ -373,17 +373,19 @@ namespace Tensorflow
             loss = _scale_loss(loss);
             int num_towers = 1;

-            var tmp = variables.trainable_variables();
-            var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
-            switch (tmp)
+            if(var_list == null)
             {
-                case List<RefVariable> values:
-                    var_list = values.Concat(vars).ToList();
-                    break;
-                case List<VariableV1> values:
-                    var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                    break;
+                var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var tmp = variables.trainable_variables();
+                switch (tmp)
+                {
+                    case List<RefVariable> values:
+                        var_list = values.Concat(vars).ToList();
+                        break;
+                    case List<VariableV1> values:
+                        var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
+                        break;
+                }
             }

             var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
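
Note: guarding with var_list == null means a caller-supplied variable list is now respected instead of being overwritten by the trainable-variables collections. Usage sketch, mirroring the YOLO example below:

    // only the listed variables receive gradient updates
    var adam = tf.train.AdamOptimizer(learn_rate);
    var first_stage_optimizer = adam.minimize(loss, var_list: first_stage_trainable_var_list);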
@@ -133,7 +133,7 @@ namespace Tensorflow
             var check_collection_list = graph.get_all_collection_keys();
             foreach (var collection_type in check_collection_list)
             {
-                var cols = graph.get_collection(collection_type);
+                /*var cols = graph.get_collection(collection_type);
                 switch (cols)
                 {
                     case List<Tensor> values:
@@ -165,7 +165,7 @@ namespace Tensorflow
                         break;
                     default:
                         throw new NotImplementedException("_build_internal.check_collection_list");
-                }
+                }*/
             }
@@ -73,9 +73,9 @@ namespace Tensorflow
             return get_default_graph().get_collection<T>(key, scope);
         }

-        public static object get_collection_ref(string key)
+        public static List<T> get_collection_ref<T>(string key)
         {
-            return get_default_graph().get_collection_ref(key);
+            return get_default_graph().get_collection_ref<T>(key);
         }

         /// <summary>
@@ -52,6 +52,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
         Tensor learn_rate;
         Tensor loss;
         List<RefVariable> first_stage_trainable_var_list;
+        Operation train_op_with_frozen_variables;
+        Operation train_op_with_all_variables;
         #endregion

         public bool Run()
@@ -153,6 +155,33 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
                 var adam = tf.train.AdamOptimizer(learn_rate);
                 var first_stage_optimizer = adam.minimize(loss, var_list: first_stage_trainable_var_list);
+                tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
+                {
+                    tf_with(tf.control_dependencies(new ITensorOrOperation[] { first_stage_optimizer, global_step_update }), delegate
+                    {
+                        tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
+                        {
+                            train_op_with_frozen_variables = tf.no_op();
+                        });
+                    });
+                });
+            });
+
+            tf_with(tf.name_scope("define_second_stage_train"), delegate
+            {
+                var second_stage_trainable_var_list = tf.trainable_variables().Select(x => x as RefVariable).ToList();
+                var adam = tf.train.AdamOptimizer(learn_rate);
+                var second_stage_optimizer = adam.minimize(loss, var_list: second_stage_trainable_var_list);
+                tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
+                {
+                    tf_with(tf.control_dependencies(new ITensorOrOperation[] { second_stage_optimizer, global_step_update }), delegate
+                    {
+                        tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
+                        {
+                            train_op_with_all_variables = tf.no_op();
+                        });
+                    });
+                });
             });

             return graph;
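
Note: each stage's train op is a tf.no_op anchored behind three nested dependency scopes, so one run of the anchor executes the UPDATE_OPS collection, the optimizer step plus global-step increment, and the moving-average update, in that order. Usage sketch (session API as in the surrounding examples):

    // running the anchor triggers the whole dependency chain
    sess.run(train_op_with_frozen_variables); // stage 1: train the unfrozen head
    sess.run(train_op_with_all_variables);    // stage 2: fine-tune all variables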
@@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text
             var y_one_hot = tf.one_hot(y, num_class);
             loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));

-            var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>;
+            var update_ops = tf.get_collection<object>(tf.GraphKeys.UPDATE_OPS);
             tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate
             {
                 var adam = tf.train.AdamOptimizer(learning_rate);
@@ -422,7 +422,7 @@ namespace TensorFlowNET.UnitTest
                 new_saver.restore(sess, dir + "my-model-10000");
                 var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
                 var batch_size = tf.size(labels);
-                var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0] as Tensor;
+                var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor;
                 var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
                                                 logits: logits);
             }
@@ -1495,6 +1495,7 @@ namespace TensorFlowNET.UnitTest
             #endregion
         }

+        [Ignore("Not finished yet")]
         [TestMethod]
         public void map_fn()
         {