Browse Source

Remove unnecessary parameters for RecordGradient.

tags/v0.30
Oceania2018 4 years ago
parent
commit
76350f74f2
11 changed files with 65 additions and 29 deletions
  1. +31
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +12
    -4
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  3. +0
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
  4. +0
    -2
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs
  5. +3
    -5
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
  6. +3
    -1
      src/TensorFlowNET.Core/Eager/IEagerRunner.cs
  7. +0
    -3
      src/TensorFlowNET.Core/Gradients/Tape.cs
  8. +12
    -5
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  10. +3
    -2
      src/TensorFlowNET.Keras/Engine/Functional.cs
  11. +0
    -3
      src/TensorFlowNET.Keras/KerasApi.cs

+ 31
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -18,6 +18,7 @@ using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Contexts
{
@@ -114,6 +115,36 @@ namespace Tensorflow.Contexts
}
}

[DebuggerStepThrough]
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
Func<Tensors> eagerAction,
Action<Operation> recordGradient,
Tensors tensors)
{
var shouldRunInEager = executing_eagerly()
&& tensors.Count(x => x.IsEagerTensor) == tensors.Length;

if (shouldRunInEager)
return eagerAction();
else
{
if (executing_eagerly())
{
graph_mode();
var result = graphAction();
restore_mode();
return result;
}
else
{
var result = graphAction();
if (tf.Runner.MustRecordGradient())
recordGradient(result[0].op);
return result;
}
}
}

public void Dispose()
=> Handle.Dispose();
}


+ 12
- 4
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -11,7 +11,8 @@ namespace Tensorflow.Eager
public bool RecordGradient(string op_name,
Tensor[] inputs,
object[] attrs,
Tensor[] results)
Tensor[] results,
Func<BackwardFunction> getBackwardFunction = null)
{
var input_ids = MakeTensorIDList(inputs);
var input_dtypes = MakeTensorDtypeList(inputs);
@@ -77,13 +78,20 @@ namespace Tensorflow.Eager
else
op_inputs = inputs;

TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes,
() => GetGradientFunction(op_name, inputs, attrs, results));

TapeSetRecordOperation(op_name, inputs, results,
getBackwardFunction ?? GetBackwradFunction(op_name, inputs, attrs, results));

return true;
}

Func<BackwardFunction> GetBackwradFunction(string op_name,
Tensor[] op_inputs,
object[] attrs,
Tensor[] op_outputs)
{
return () => GetGradientFunction(op_name, op_inputs, attrs, op_outputs);
}

BackwardFunction GetGradientFunction(string op_name,
Tensor[] op_inputs,
object[] attrs,


+ 0
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs View File

@@ -10,8 +10,6 @@ namespace Tensorflow.Eager
void TapeSetRecordBackprop(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter)
{
if (!CouldBackprop())
@@ -22,7 +20,6 @@ namespace Tensorflow.Eager
foreach (var tape in tf.GetTapeSet())
{
tape.RecordOperation(op_type, input_tensors, output_tensors,
input_ids, input_dtypes,
backward_function_getter);
}
}


+ 0
- 2
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs View File

@@ -9,8 +9,6 @@ namespace Tensorflow.Eager
bool TapeSetRecordForwardprop(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter)
{
if (!CouldForwardprop())


+ 3
- 5
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs View File

@@ -7,11 +7,9 @@ namespace Tensorflow.Eager
{
public partial class EagerRunner
{
bool TapeSetRecordOperation(string op_type,
public bool TapeSetRecordOperation(string op_type,
Tensor[] input_tensors,
Tensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter)
{
var output_info = new List<TapeTensor>();
@@ -20,11 +18,11 @@ namespace Tensorflow.Eager
return false;

if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(),
input_ids, input_dtypes, backward_function_getter))
backward_function_getter))
return false;

TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(),
input_ids, input_dtypes, backward_function_getter);
backward_function_getter);

return true;
}


+ 3
- 1
src/TensorFlowNET.Core/Eager/IEagerRunner.cs View File

@@ -1,6 +1,7 @@
using System;
using Tensorflow.Contexts;
using Tensorflow.Gradients;
using static Tensorflow.tensorflow;

namespace Tensorflow.Eager
{
@@ -37,7 +38,8 @@ namespace Tensorflow.Eager
bool RecordGradient(string op_name,
Tensor[] inputs,
object[] attrs,
Tensor[] results);
Tensor[] results,
Func<BackwardFunction> getBackwardFunction = null);

bool MustRecordGradient();



+ 0
- 3
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -54,10 +54,7 @@ namespace Tensorflow.Gradients
if (tensor_tape_.find(tensor_ids[i]))
{
if (IsDtypeTrainable(dtypes[i]))
{
tf.Logger.Debug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}");
return true;
}
}
}
return false;


+ 12
- 5
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Tensorflow.Gradients;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -47,11 +48,17 @@ namespace Tensorflow
{
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
(oper, out_grads) =>
g.InvokeMember(m.Name,
BindingFlags.InvokeMethod,
null,
null,
args: new object[] { oper, out_grads }) as Tensor[]
{
tf.Logger.Debug($"Caculate Gradient: {m.Name}");
var results = g.InvokeMember(m.Name,
BindingFlags.InvokeMethod,
null,
null,
args: new object[] { oper, out_grads }) as Tensor[];
foreach (var result in results.Where(x => x != null))
tf.Logger.Debug($"{result.TensorShape}");
return results;
}
);
}



+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -358,7 +358,7 @@ namespace Tensorflow
}

var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy });
return op.output;
}



+ 3
- 2
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -335,9 +335,10 @@ namespace Tensorflow.Keras.Engine

var layer_inputs = node.MapArguments(tensor_dict);

tf.Logger.Debug($"{node.Layer}: {node.Layer.Name}");
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}");
var outputs = node.Layer.Apply(layer_inputs, is_training: training);

foreach (var output in outputs.Where(x => x != null))
tf.Logger.Debug($"{output.TensorShape}");
// Update tensor_dict for next input
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));


+ 0
- 3
src/TensorFlowNET.Keras/KerasApi.cs View File

@@ -4,9 +4,6 @@ namespace Tensorflow
{
public static class KerasApi
{
public static KerasInterface Keras(this tensorflow tf)
=> new KerasInterface();

public static KerasInterface keras { get; } = new KerasInterface();
}
}

Loading…
Cancel
Save