Browse Source

gradients_utils: minor changes

tags/v0.12
Eli Belash 6 years ago
parent
commit
70c877d18f
2 changed files with 5 additions and 3 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  2. +2
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings

+ 3
- 3
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
string grad_scope = scope; string grad_scope = scope;
// Get a uid for this call to gradients that can be used to help // Get a uid for this call to gradients that can be used to help
// cluster ops for compilation. // cluster ops for compilation.
var gradient_uid = ops.get_default_graph().unique_name("uid");
var gradient_uid = curr_graph.unique_name("uid");
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);
@@ -80,7 +80,7 @@ namespace Tensorflow
var to_ops = ys.Select(x => x.op).ToList(); var to_ops = ys.Select(x => x.op).ToList();
var from_ops = xs.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList();
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
(var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);


foreach (var (y, grad_y) in zip(ys, grad_ys)) foreach (var (y, grad_y) in zip(ys, grad_ys))
_SetGrad(grads, y, grad_y); _SetGrad(grads, y, grad_y);
@@ -168,7 +168,7 @@ namespace Tensorflow
{ {
if (in_grad != null) if (in_grad != null)
{ {
if (in_grad is Tensor &&
if (!(in_grad is null) &&
in_grad.Tag == null && // maybe a IndexedSlice in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE) t_in.dtype != TF_DataType.TF_RESOURCE)
{ {


+ 2
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings View File

@@ -0,0 +1,2 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=utilities/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

Loading…
Cancel
Save