diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 8fd922d3..9df2f333 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -34,26 +34,7 @@ namespace Tensorflow
///
public static Tensor reduce_sum(Tensor input, int[] axis = null)
{
- Tensor rank;
- string name;
- using (var namescop = new ops.name_scope("", "Rank", new List { input }))
- {
- name = namescop;
- rank = gen_array_ops.rank(input, namescop);
- }
-
- using (var namescope = new ops.name_scope("range", "Range", new List { 0D, input, 1D }))
- {
- name = namescope;
- var start = ops.convert_to_tensor(0D);
- var limit = ops.convert_to_tensor(input);
- var delta = ops.convert_to_tensor(1D);
-
- var t = gen_math_ops.range(start, limit, delta, name);
- }
-
- var s = gen_math_ops.sum(input, rank);
- return s;
+ return math_ops.reduce_sum(input);
}
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index 456158e5..71801385 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -181,7 +181,11 @@ namespace Tensorflow
{
foreach(var x in _NonEagerInputs(op, xs))
{
- pending_count[x.op.Name] -= 1;
+ if (!pending_count.ContainsKey(x.op.Name))
+ pending_count[x.op.Name] = 0;
+ else
+ pending_count[x.op.Name] -= 1;
+
var ready = pending_count[x.op.Name] == 0;
if(loop_state != null && !ready)
@@ -440,22 +444,43 @@ namespace Tensorflow
reached_ops.Add(op);
foreach (var output in op.outputs)
{
- var c = _Consumers(output, func_graphs).ToList();
- c.ForEach(x => queue.Enqueue(x));
+ if (_IsBackpropagatable(output))
+ {
+ var c = _Consumers(output, func_graphs).ToList();
+ c.ForEach(x => queue.Enqueue(x));
+ }
}
}
}
}
+ private static bool _IsTrainable(Tensor tensor)
+ {
+ var dtype = tensor.dtype.as_base_dtype();
+ return new TF_DataType[] {TF_DataType.TF_HALF, TF_DataType.TF_FLOAT, TF_DataType.TF_DOUBLE,
+ TF_DataType.TF_COMPLEX64, TF_DataType.TF_COMPLEX128, TF_DataType.TF_RESOURCE}.Contains(dtype);
+ }
+ private static bool _IsBackpropagatable(Tensor tensor)
+ {
+ if(_IsTrainable(tensor))
+ {
+ return true;
+ }
+ else
+ {
+ var dtype = tensor.dtype.as_base_dtype();
+ return new TF_DataType[] { TF_DataType.TF_BFLOAT16, TF_DataType.TF_VARIANT }.Contains(dtype);
+ }
+ }
+
///
/// Returns the consumers of t, crossing closure boundaries where necessary.
///
///
///
- private static List _Consumers(Tensor t, List