
make get_collection<T> generic.

Oceania2018 · 6 years ago · parent commit 5c891eaa49 · tags/v0.12
17 changed files with 172 additions and 99 deletions
  1. src/TensorFlowNET.Core/APIs/tf.control_flow.cs (+1 -1)
  2. src/TensorFlowNET.Core/APIs/tf.ops.cs (+3 -3)
  3. src/TensorFlowNET.Core/Framework/meta_graph.py.cs (+1 -1)
  4. src/TensorFlowNET.Core/Gradients/control_flow_grad.cs (+77 -67)
  5. src/TensorFlowNET.Core/Gradients/gradients_util.cs (+17 -2)
  6. src/TensorFlowNET.Core/Graphs/Graph.cs (+10 -3)
  7. src/TensorFlowNET.Core/Layers/Layer.cs (+1 -1)
  8. src/TensorFlowNET.Core/Operations/Operation.Control.cs (+4 -0)
  9. src/TensorFlowNET.Core/Operations/Operation.cs (+5 -0)
 10. src/TensorFlowNET.Core/Train/AdamOptimizer.cs (+4 -4)
 11. src/TensorFlowNET.Core/Train/Optimizer.cs (+13 -11)
 12. src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs (+2 -2)
 13. src/TensorFlowNET.Core/ops.cs (+2 -2)
 14. test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs (+29 -0)
 15. test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs (+1 -1)
 16. test/TensorFlowNET.UnitTest/GraphTest.cs (+1 -1)
 17. test/TensorFlowNET.UnitTest/OperationsTest.cs (+1 -0)

src/TensorFlowNET.Core/APIs/tf.control_flow.cs (+1 -1)

@@ -54,7 +54,7 @@ namespace Tensorflow
maximum_iterations: maximum_iterations,
return_same_structure: return_same_structure);

-public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> ops.control_dependencies(control_inputs);
}
}

src/TensorFlowNET.Core/APIs/tf.ops.cs (+3 -3)

@@ -36,8 +36,8 @@ namespace Tensorflow
public void device(string device_name)
=> get_default_graph().device(device_name);

-public object get_collection(string key, string scope = "")
-    => get_default_graph().get_collection(key, scope: scope);
+public List<T> get_collection<T>(string key, string scope = "")
+    => get_default_graph().get_collection<T>(key, scope: scope);

/// <summary>
/// A context manager that lifts ops out of control-flow scopes and function-building graphs.
@@ -60,7 +60,7 @@ namespace Tensorflow
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
-public Tensor no_op(string name = null)
+public Operation no_op(string name = null)
=> gen_control_flow_ops.no_op(name: name);

/// <summary>

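Note: this is the headline change of the commit. A minimal sketch of how the generic overload reads at a call site, assuming ops were previously added to the UPDATE_OPS collection (requires `using System.Linq;` and `using static Tensorflow.Binding;`):

    // Before: untyped result, downcast at every call site:
    // var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>;

    // After: the element type is stated once and no cast is needed:
    List<Operation> update_ops = tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS);
    tf_with(tf.control_dependencies(update_ops.ToArray()), delegate
    {
        // ops created here only run after every collected update op
    });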

src/TensorFlowNET.Core/Framework/meta_graph.py.cs (+1 -1)

@@ -180,7 +180,7 @@ namespace Tensorflow
var graph = ops.get_default_graph();

var var_list = new Dictionary<string, RefVariable>();
-var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;
+var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);

if (variables != null)
{


src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs → src/TensorFlowNET.Core/Gradients/control_flow_grad.cs (+77 -67)

@@ -1,28 +1,30 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
/// <summary>
/// Gradients for operators defined in control_flow_ops.py.cs
/// </summary>
[RegisterGradient("control_flow_grad")]
public class control_flow_grad
{
/// <summary>
@@ -33,6 +35,7 @@ namespace Tensorflow.Gradients
/// on the second visit. A next_iteration is also added on second visit.
/// </summary>
/// <returns></returns>
[RegisterGradient("Switch")]
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
{
throw new NotImplementedException("_SwitchGrad");
@@ -83,68 +86,68 @@ namespace Tensorflow.Gradients
// false_grad = switch(grad[0], op.inputs[1])[0]
// true_grad = switch(grad[1], op.inputs[1])[1]
// return merge([false_grad, true_grad])[0], None
}
}
/// <summary>
/// Gradients for a Merge op are calculated using a Switch op.
/// </summary>
[RegisterGradient("Merge")]
[RegisterGradient("Merge")]
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var _ = grads[1];
var input_op = op.inputs[0].op;
var graph = ops.get_default_graph();
var op_ctxt = control_flow_util.GetOutputContext(input_op);
var grad_ctxt = graph._get_control_flow_context();
switch (op_ctxt)
{
case WhileContext cwhile:
{
return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot);
}
case CondContext ccond:
{
var pred = ccond.pred;
if (grad_ctxt != null && grad_ctxt.grad_state != null)
{
//# This Merge node is part of a cond within a loop.
//# The backprop needs to have the value of this predicate for every
//# iteration. So we must have its values accumulated in the forward, and
//# use the accumulated values as the predicate for this backprop switch.
var grad_state = grad_ctxt.grad_state;
var real_pred = grad_state.history_map[pred.name] as Tensor;
if (real_pred == null)
{
//# Remember the value of pred for every iteration.
grad_ctxt = grad_state.grad_context;
grad_ctxt.Exit();
var history_pred = grad_state.AddForwardAccumulator(pred);
grad_ctxt.Enter();
//# Add the stack pop op. If pred.op is in a (outer) CondContext,
//# the stack pop will be guarded with a switch.
real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred);
grad_state.history_map[pred.name] = real_pred;
}
pred = real_pred;
}
var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
return results;
}
default:
{
var num_inputs = op.inputs.Length;
var cond = new Tensor[num_inputs];
for (int i = 0; i < num_inputs; i++)
cond[i] = math_ops.equal(op.outputs[1], i);
var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray();
return result;
}
}

}

[RegisterGradient("RefMerge")]
public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
{
return _MergeGrad(op, grads);
@@ -153,6 +156,7 @@ namespace Tensorflow.Gradients
/// <summary>
/// Gradients for an exit op are calculated using an Enter op.
/// </summary>
[RegisterGradient("Exit")]
public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
{
throw new NotImplementedException("_ExitGrad");
@@ -197,14 +201,16 @@ namespace Tensorflow.Gradients
///
/// Note that the backprop next_iteration is added in switch grad.
/// </summary>
-public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad)
+[RegisterGradient("NextIteration")]
+public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
{
-    return (_, grad);
+    return grad;
}

-public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad)
+[RegisterGradient("RefNextIteration")]
+public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
{
-    return (_, grad);
+    return grad;
}

/// <summary>
@@ -213,7 +219,8 @@ namespace Tensorflow.Gradients
/// For loop variables, grad is the gradient so just add an exit.
/// For loop invariants, we need to add an accumulator loop.
/// </summary>
-public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad)
+[RegisterGradient("Enter")]
+public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
{
throw new NotImplementedException("_EnterGrad");
// graph = ops.get_default_graph()
@@ -242,7 +249,9 @@ namespace Tensorflow.Gradients
// return result
}

-public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad)
+
+[RegisterGradient("RefEnter")]
+public Tensor[] _RefEnterGrad(Tensor op, Tensor[] grad)
{
return _EnterGrad(op, grad);
}
@@ -250,10 +259,11 @@ namespace Tensorflow.Gradients
/// <summary>
/// Stop backprop for the predicate of a while loop.
/// </summary>
-public object _LoopCondGrad(object _)
+[RegisterGradient("LoopCond")]
+public Tensor[] _LoopCondGrad(Tensor op, Tensor[] grad)
{
return null;
}
}
}
}
}
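Note: the methods above are now tagged with [RegisterGradient("OpType")] so the gradient registry resolves them by the forward op's type. A minimal sketch of the pattern ("MyOp" and the pass-through body are illustrative, not part of this commit):

    // A gradient function receives the forward op and the incoming
    // gradients, and must return one gradient per input of that op.
    [RegisterGradient("MyOp")]
    public static Tensor[] _MyOpGrad(Operation op, Tensor[] grads)
    {
        return grads; // identity gradient, purely for illustration
    }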

src/TensorFlowNET.Core/Gradients/gradients_util.cs (+17 -2)

@@ -108,7 +108,10 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
+if(tf.get_default_graph()._nodes_by_name.Count > 18505)
+{
+
+}
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);
@@ -157,8 +160,12 @@ namespace Tensorflow
// therefore dC/doutput[i] is 0.
foreach (var (i, out_grad) in enumerate(out_grads))
{
-if (out_grad == null)
+if (out_grad == null &&
+    (grad_fn == null || _IsTrainable(op.outputs[i])))
{
+// Only trainable outputs or outputs for a function call that
+// will use SymbolicGradient get a zero gradient. Gradient
+// functions should ignore the gradient for other outputs.
if (loop_state != null)
;
else
@@ -170,7 +177,15 @@ namespace Tensorflow
{
if (grad_fn != null)
{
-in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn);
+in_grads = _MaybeCompile(grad_scope,
+    op,
+    out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
+    null,
+    grad_fn);
}
else
{
throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
}
_VerifyGeneratedGradients(in_grads, op);
if (gate_gradients && in_grads.Count(x => x != null) > 1)


src/TensorFlowNET.Core/Graphs/Graph.cs (+10 -3)

@@ -227,6 +227,10 @@ namespace Tensorflow

public void add_to_collection<T>(string name, T value)
{
if(name == "update_ops")
{
}
_check_not_finalized();
if (_collections.ContainsKey(name))
(_collections[name] as List<T>).Add(value);
@@ -442,17 +446,20 @@ namespace Tensorflow
case List<Tensor> list:
t = list.Select(x => (T)(object)x).ToList();
break;
+case List<Operation> list:
+    t = list.Select(x => (T)(object)x).ToList();
+    break;
default:
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
}
return t;
}

-public object get_collection_ref(string name)
+public List<T> get_collection_ref<T>(string name)
{
if (!_collections.ContainsKey(name))
-    _collections[name] = new List<object>();
-return _collections[name];
+    _collections[name] = new List<T>();
+return _collections[name] as List<T>;
}

public void prevent_feeding(Tensor tensor)

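Note: unlike get_collection<T>, get_collection_ref<T> returns the live backing list, creating it on first access; and because the stored list is cast with `as List<T>`, requesting a mismatched element type for an existing collection yields null rather than throwing. A minimal sketch, assuming a Graph instance `graph`, an Operation `op`, and an illustrative key "my_ops":

    var a = graph.get_collection_ref<Operation>("my_ops"); // creates List<Operation>
    a.Add(op);                                             // mutates the stored collection
    var b = graph.get_collection_ref<Operation>("my_ops");
    // a and b are the same list, so b contains op;
    // graph.get_collection_ref<Tensor>("my_ops") would return null here.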

src/TensorFlowNET.Core/Layers/Layer.cs (+1 -1)

@@ -90,7 +90,7 @@ namespace Tensorflow.Layers
{
foreach(var name in collection_list)
{
-var collection = ops.get_collection_ref(name) as List<object>;
+var collection = ops.get_collection_ref<Operation>(name);

foreach (var element in elements)
if (!collection.Contains(element))


src/TensorFlowNET.Core/Operations/Operation.Control.cs (+4 -0)

@@ -54,6 +54,10 @@ namespace Tensorflow
public void _set_control_flow_context(ControlFlowContext ctx)
{
if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1")
{
}
_control_flow_context = ctx;
}


src/TensorFlowNET.Core/Operations/Operation.cs (+5 -0)

@@ -151,6 +151,11 @@ namespace Tensorflow
}
}

+if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1")
+{
+}

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();


src/TensorFlowNET.Core/Train/AdamOptimizer.cs (+4 -4)

@@ -154,10 +154,10 @@ namespace Tensorflow.Train
var beta2 = _call_if_callable(_beta2);
var epsilon = _call_if_callable(_epsilon);

-_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
-_beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
-_beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
-_epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
+_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");
+_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1");
+_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2");
+_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon");
}
}
}
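Note: switching to null-coalescing assignment makes _prepare idempotent, so the conversion ops are created only once per optimizer instance. The long form of the same idea:

    // `x = x ?? expr;` evaluates expr only when x is null:
    if (_lr_t == null)
        _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");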

src/TensorFlowNET.Core/Train/Optimizer.cs (+13 -11)

@@ -212,7 +212,7 @@ namespace Tensorflow

if (!tf.context.executing_eagerly())
{
-var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
+var train_op = ops.get_collection_ref<Operation>(tf.GraphKeys.TRAIN_OP);
if (train_op != null && train_op.Contains(apply_updates))
train_op.Add(apply_updates);
}
@@ -373,17 +373,19 @@ namespace Tensorflow
loss = _scale_loss(loss);
int num_towers = 1;


-var tmp = variables.trainable_variables();
-var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
-switch (tmp)
-{
-    case List<RefVariable> values:
-        var_list = values.Concat(vars).ToList();
-        break;
-    case List<VariableV1> values:
-        var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-        break;
-}
+if(var_list == null)
+{
+    var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+    var tmp = variables.trainable_variables();
+    switch (tmp)
+    {
+        case List<RefVariable> values:
+            var_list = values.Concat(vars).ToList();
+            break;
+        case List<VariableV1> values:
+            var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
+            break;
+    }
+}

var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
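Note: with the new null guard, an explicitly supplied var_list takes precedence and the collection lookup is only a fallback. Usage sketch, with learn_rate, loss and first_stage_trainable_var_list as in the YOLO example further down:

    var adam = tf.train.AdamOptimizer(learn_rate);
    // var_list given: only these variables receive gradient updates.
    var op1 = adam.minimize(loss, var_list: first_stage_trainable_var_list);
    // var_list omitted (null): minimize falls back to TRAINABLE_VARIABLES
    // plus TRAINABLE_RESOURCE_VARIABLES, per the switch above.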


src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs (+2 -2)

@@ -133,7 +133,7 @@ namespace Tensorflow
var check_collection_list = graph.get_all_collection_keys();
foreach (var collection_type in check_collection_list)
{
-var cols = graph.get_collection(collection_type);
+/*var cols = graph.get_collection(collection_type);
switch (cols)
{
case List<Tensor> values:
@@ -165,7 +165,7 @@ namespace Tensorflow
break;
default:
throw new NotImplementedException("_build_internal.check_collection_list");
-}
+}*/
}



src/TensorFlowNET.Core/ops.cs (+2 -2)

@@ -73,9 +73,9 @@ namespace Tensorflow
return get_default_graph().get_collection<T>(key, scope);
}

-public static object get_collection_ref(string key)
+public static List<T> get_collection_ref<T>(string key)
{
-return get_default_graph().get_collection_ref(key);
+return get_default_graph().get_collection_ref<T>(key);
}

/// <summary>


test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs (+29 -0)

@@ -52,6 +52,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
Tensor learn_rate;
Tensor loss;
List<RefVariable> first_stage_trainable_var_list;
+Operation train_op_with_frozen_variables;
+Operation train_op_with_all_variables;
#endregion

public bool Run()
@@ -153,6 +155,33 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

var adam = tf.train.AdamOptimizer(learn_rate);
var first_stage_optimizer = adam.minimize(loss, var_list: first_stage_trainable_var_list);
+tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
+{
+    tf_with(tf.control_dependencies(new ITensorOrOperation[] { first_stage_optimizer, global_step_update }), delegate
+    {
+        tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
+        {
+            train_op_with_frozen_variables = tf.no_op();
+        });
+    });
+});
+});

tf_with(tf.name_scope("define_second_stage_train"), delegate
{
var second_stage_trainable_var_list = tf.trainable_variables().Select(x => x as RefVariable).ToList();
var adam = tf.train.AdamOptimizer(learn_rate);
var second_stage_optimizer = adam.minimize(loss, var_list: second_stage_trainable_var_list);
tf_with(tf.control_dependencies(tf.get_collection<Operation>(tf.GraphKeys.UPDATE_OPS).ToArray()), delegate
{
tf_with(tf.control_dependencies(new ITensorOrOperation[] { second_stage_optimizer, global_step_update }), delegate
{
tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
{
train_op_with_all_variables = tf.no_op();
});
});
});
});

return graph;
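Note: nested control_dependencies scopes accumulate, so the tf.no_op() created innermost picks up every listed op as a control input; running the returned Operation forces the update ops, the optimizer step, the global-step increment, and the moving-average update, even though the no-op itself computes nothing. Distilled sketch (update_ops, optimizer_op, global_step_update, moving_ave and train_op stand in for the concrete names above):

    tf_with(tf.control_dependencies(update_ops), delegate
    {
        tf_with(tf.control_dependencies(new ITensorOrOperation[] { optimizer_op, global_step_update }), delegate
        {
            tf_with(tf.control_dependencies(new[] { moving_ave }), delegate
            {
                // tf.no_op (now typed Operation) carries all the control
                // inputs from the three enclosing scopes.
                train_op = tf.no_op();
            });
        });
    });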


test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs (+1 -1)

@@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text
var y_one_hot = tf.one_hot(y, num_class);
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));

-var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>;
+var update_ops = tf.get_collection<object>(tf.GraphKeys.UPDATE_OPS);
tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate
{
var adam = tf.train.AdamOptimizer(learning_rate);


test/TensorFlowNET.UnitTest/GraphTest.cs (+1 -1)

@@ -422,7 +422,7 @@ namespace TensorFlowNET.UnitTest
new_saver.restore(sess, dir + "my-model-10000");
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
var batch_size = tf.size(labels);
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0] as Tensor;
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor;
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
logits: logits);
}


test/TensorFlowNET.UnitTest/OperationsTest.cs (+1 -0)

@@ -1495,6 +1495,7 @@ namespace TensorFlowNET.UnitTest
#endregion
}

[Ignore("Not finished yet")]
[TestMethod]
public void map_fn()
{

