Browse Source

fix object reference issue for _AggregatedGrads #303

tags/v0.10
Oceania2018 6 years ago
parent
commit
99bd08b1da
4 changed files with 46 additions and 14 deletions
  1. +2
    -2
      src/TensorFlowHub/TensorFlowHub.csproj
  2. +10
    -11
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  3. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
  4. +33
    -0
      test/TensorFlowNET.UnitTest/GradientTest.cs

+ 2
- 2
src/TensorFlowHub/TensorFlowHub.csproj View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup> <PropertyGroup>
<AssemblyName>TensorFlow.Net.Hub</AssemblyName> <AssemblyName>TensorFlow.Net.Hub</AssemblyName>
<RootNamespace>Tensorflow.Hub</RootNamespace> <RootNamespace>Tensorflow.Hub</RootNamespace>
@@ -8,7 +8,7 @@
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="NumSharp" Version="0.10.4" />
<PackageReference Include="NumSharp" Version="0.10.5" />
<PackageReference Include="sharpcompress" Version="0.23.0" /> <PackageReference Include="sharpcompress" Version="0.23.0" />
</ItemGroup> </ItemGroup>
</Project> </Project>

+ 10
- 11
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -137,7 +137,7 @@ namespace Tensorflow
if (loop_state != null) if (loop_state != null)
; ;
else else
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i);
out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
} }
} }


@@ -146,7 +146,7 @@ namespace Tensorflow
string name1 = scope1; string name1 = scope1;
if (grad_fn != null) if (grad_fn != null)
{ {
in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
in_grads = _MaybeCompile(grad_scope, op, out_grads[0].ToArray(), null, grad_fn);
_VerifyGeneratedGradients(in_grads, op); _VerifyGeneratedGradients(in_grads, op);
} }


@@ -310,10 +310,9 @@ namespace Tensorflow
yield return op.inputs[i]; yield return op.inputs[i];
} }


private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
private static List<List<Tensor>> _AggregatedGrads(Dictionary<string, List<List<Tensor>>> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
{ {
var out_grads = _GetGrads(grads, op); var out_grads = _GetGrads(grads, op);
var return_grads = new Tensor[out_grads.Count];


foreach (var (i, out_grad) in enumerate(out_grads)) foreach (var (i, out_grad) in enumerate(out_grads))
{ {
@@ -334,21 +333,21 @@ namespace Tensorflow
throw new ValueError("_AggregatedGrads out_grad.Length == 0"); throw new ValueError("_AggregatedGrads out_grad.Length == 0");
} }


return_grads[i] = out_grad[0];
out_grads[i] = out_grad;
} }
else else
{ {
used = "add_n"; used = "add_n";
return_grads[i] = _MultiDeviceAddN(out_grad.ToArray(), gradient_uid);
out_grads[i] = new List<Tensor> { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) };
} }
} }
else else
{ {
return_grads[i] = null;
out_grads[i] = null;
} }
} }


return return_grads;
return out_grads;
} }


/// <summary> /// <summary>
@@ -362,7 +361,7 @@ namespace Tensorflow
// Basic function structure comes from control_flow_ops.group(). // Basic function structure comes from control_flow_ops.group().
// Sort tensors according to their devices. // Sort tensors according to their devices.
var tensors_on_device = new Dictionary<string, List<Tensor>>(); var tensors_on_device = new Dictionary<string, List<Tensor>>();
foreach (var tensor in tensor_list) foreach (var tensor in tensor_list)
{ {
if (!tensors_on_device.ContainsKey(tensor.Device)) if (!tensors_on_device.ContainsKey(tensor.Device))
@@ -370,10 +369,10 @@ namespace Tensorflow


tensors_on_device[tensor.Device].Add(tensor); tensors_on_device[tensor.Device].Add(tensor);
} }
// For each device, add the tensors on that device first. // For each device, add the tensors on that device first.
var summands = new List<Tensor>(); var summands = new List<Tensor>();
foreach(var dev in tensors_on_device.Keys)
foreach (var dev in tensors_on_device.Keys)
{ {
var tensors = tensors_on_device[dev]; var tensors = tensors_on_device[dev];
ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true); ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true);


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples.ImageProcess
/// </summary> /// </summary>
public class DigitRecognitionRNN : IExample public class DigitRecognitionRNN : IExample
{ {
public bool Enabled { get; set; } = true;
public bool Enabled { get; set; } = false;
public bool IsImportingGraph { get; set; } = false; public bool IsImportingGraph { get; set; } = false;


public string Name => "MNIST RNN"; public string Name => "MNIST RNN";


+ 33
- 0
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -2,6 +2,7 @@
using NumSharp; using NumSharp;
using System.Linq; using System.Linq;
using Tensorflow; using Tensorflow;
using static Tensorflow.Python;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
@@ -28,6 +29,38 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(g[1].name, "gradients/Fill:0"); Assert.AreEqual(g[1].name, "gradients/Fill:0");
} }


/// <summary>
/// Checks the gradient of y = x * x * 0.1 with respect to x at x = 7.
/// Analytically dy/dx = 0.2 * x, i.e. 1.4 at x = 7.
/// </summary>
[TestMethod]
public void Gradient2x()
{
    var graph = tf.Graph().as_default();
    with(tf.Session(graph), sess => {
        var x = tf.constant(7.0f);
        var y = x * x * tf.constant(0.1f);

        var grad = tf.gradients(y, x);
        // MSTest's AreEqual takes (expected, actual); keeping that order makes failure messages read correctly.
        // NOTE(review): the AddN name presumably reflects x feeding the product twice, so its partial gradients are summed - confirm.
        Assert.AreEqual("gradients/AddN:0", grad[0].name);

        float r = sess.run(grad[0]);
        // Compare floats with a tolerance instead of exact equality to avoid spurious failures from rounding.
        Assert.AreEqual(1.4f, r, 1e-5f);
    });
}

/// <summary>
/// Checks the gradient of y = x * x * x * 0.1 with respect to x at x = 7.
/// Analytically dy/dx = 0.3 * x^2, i.e. 14.7 at x = 7.
/// </summary>
[TestMethod]
public void Gradient3x()
{
    var graph = tf.Graph().as_default();
    with(tf.Session(graph), sess => {
        var x = tf.constant(7.0f);
        var y = x * x * x * tf.constant(0.1f);

        var grad = tf.gradients(y, x);
        // MSTest's AreEqual takes (expected, actual); keeping that order makes failure messages read correctly.
        Assert.AreEqual("gradients/AddN:0", grad[0].name);

        float r = sess.run(grad[0]);
        // Single-precision evaluation yields roughly 14.700001; assert against the analytic value
        // with a tolerance rather than pinning the exact rounding artifact.
        Assert.AreEqual(14.7f, r, 1e-4f);
    });
}

[TestMethod] [TestMethod]
public void StridedSlice() public void StridedSlice()
{ {


Loading…
Cancel
Save