Browse Source

Remove unnecessary parameters for RecordGradient.

tags/v0.30
Oceania2018 4 years ago
parent
commit
76350f74f2
11 changed files with 65 additions and 29 deletions
  1. +31
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +12
    -4
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  3. +0
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
  4. +0
    -2
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs
  5. +3
    -5
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
  6. +3
    -1
      src/TensorFlowNET.Core/Eager/IEagerRunner.cs
  7. +0
    -3
      src/TensorFlowNET.Core/Gradients/Tape.cs
  8. +12
    -5
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  10. +3
    -2
      src/TensorFlowNET.Keras/Engine/Functional.cs
  11. +0
    -3
      src/TensorFlowNET.Keras/KerasApi.cs

+ 31
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -18,6 +18,7 @@ using System;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Eager; using Tensorflow.Eager;
using static Tensorflow.Binding;


namespace Tensorflow.Contexts namespace Tensorflow.Contexts
{ {
@@ -114,6 +115,36 @@ namespace Tensorflow.Contexts
} }
} }


[DebuggerStepThrough]
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
Func<Tensors> eagerAction,
Action<Operation> recordGradient,
Tensors tensors)
{
var shouldRunInEager = executing_eagerly()
&& tensors.Count(x => x.IsEagerTensor) == tensors.Length;

if (shouldRunInEager)
return eagerAction();
else
{
if (executing_eagerly())
{
graph_mode();
var result = graphAction();
restore_mode();
return result;
}
else
{
var result = graphAction();
if (tf.Runner.MustRecordGradient())
recordGradient(result[0].op);
return result;
}
}
}

public void Dispose() public void Dispose()
=> Handle.Dispose(); => Handle.Dispose();
} }


+ 12
- 4
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -11,7 +11,8 @@ namespace Tensorflow.Eager
public bool RecordGradient(string op_name, public bool RecordGradient(string op_name,
Tensor[] inputs, Tensor[] inputs,
object[] attrs, object[] attrs,
Tensor[] results)
Tensor[] results,
Func<BackwardFunction> getBackwardFunction = null)
{ {
var input_ids = MakeTensorIDList(inputs); var input_ids = MakeTensorIDList(inputs);
var input_dtypes = MakeTensorDtypeList(inputs); var input_dtypes = MakeTensorDtypeList(inputs);
@@ -77,13 +78,20 @@ namespace Tensorflow.Eager
else else
op_inputs = inputs; op_inputs = inputs;


TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes,
() => GetGradientFunction(op_name, inputs, attrs, results));

TapeSetRecordOperation(op_name, inputs, results,
getBackwardFunction ?? GetBackwradFunction(op_name, inputs, attrs, results));


return true; return true;
} }


Func<BackwardFunction> GetBackwradFunction(string op_name,
Tensor[] op_inputs,
object[] attrs,
Tensor[] op_outputs)
{
return () => GetGradientFunction(op_name, op_inputs, attrs, op_outputs);
}

BackwardFunction GetGradientFunction(string op_name, BackwardFunction GetGradientFunction(string op_name,
Tensor[] op_inputs, Tensor[] op_inputs,
object[] attrs, object[] attrs,


+ 0
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs View File

@@ -10,8 +10,6 @@ namespace Tensorflow.Eager
void TapeSetRecordBackprop(string op_type, void TapeSetRecordBackprop(string op_type,
Tensor[] input_tensors, Tensor[] input_tensors,
TapeTensor[] output_tensors, TapeTensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter) Func<BackwardFunction> backward_function_getter)
{ {
if (!CouldBackprop()) if (!CouldBackprop())
@@ -22,7 +20,6 @@ namespace Tensorflow.Eager
foreach (var tape in tf.GetTapeSet()) foreach (var tape in tf.GetTapeSet())
{ {
tape.RecordOperation(op_type, input_tensors, output_tensors, tape.RecordOperation(op_type, input_tensors, output_tensors,
input_ids, input_dtypes,
backward_function_getter); backward_function_getter);
} }
} }


+ 0
- 2
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs View File

@@ -9,8 +9,6 @@ namespace Tensorflow.Eager
bool TapeSetRecordForwardprop(string op_type, bool TapeSetRecordForwardprop(string op_type,
Tensor[] input_tensors, Tensor[] input_tensors,
TapeTensor[] output_tensors, TapeTensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter) Func<BackwardFunction> backward_function_getter)
{ {
if (!CouldForwardprop()) if (!CouldForwardprop())


+ 3
- 5
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs View File

@@ -7,11 +7,9 @@ namespace Tensorflow.Eager
{ {
public partial class EagerRunner public partial class EagerRunner
{ {
bool TapeSetRecordOperation(string op_type,
public bool TapeSetRecordOperation(string op_type,
Tensor[] input_tensors, Tensor[] input_tensors,
Tensor[] output_tensors, Tensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter) Func<BackwardFunction> backward_function_getter)
{ {
var output_info = new List<TapeTensor>(); var output_info = new List<TapeTensor>();
@@ -20,11 +18,11 @@ namespace Tensorflow.Eager
return false; return false;


if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(), if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(),
input_ids, input_dtypes, backward_function_getter))
backward_function_getter))
return false; return false;


TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(), TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(),
input_ids, input_dtypes, backward_function_getter);
backward_function_getter);


return true; return true;
} }


+ 3
- 1
src/TensorFlowNET.Core/Eager/IEagerRunner.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Gradients; using Tensorflow.Gradients;
using static Tensorflow.tensorflow;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
@@ -37,7 +38,8 @@ namespace Tensorflow.Eager
bool RecordGradient(string op_name, bool RecordGradient(string op_name,
Tensor[] inputs, Tensor[] inputs,
object[] attrs, object[] attrs,
Tensor[] results);
Tensor[] results,
Func<BackwardFunction> getBackwardFunction = null);


bool MustRecordGradient(); bool MustRecordGradient();




+ 0
- 3
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -54,10 +54,7 @@ namespace Tensorflow.Gradients
if (tensor_tape_.find(tensor_ids[i])) if (tensor_tape_.find(tensor_ids[i]))
{ {
if (IsDtypeTrainable(dtypes[i])) if (IsDtypeTrainable(dtypes[i]))
{
tf.Logger.Debug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}");
return true; return true;
}
} }
} }
return false; return false;


+ 12
- 5
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using Tensorflow.Gradients; using Tensorflow.Gradients;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -47,11 +48,17 @@ namespace Tensorflow
{ {
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
(oper, out_grads) => (oper, out_grads) =>
g.InvokeMember(m.Name,
BindingFlags.InvokeMethod,
null,
null,
args: new object[] { oper, out_grads }) as Tensor[]
{
tf.Logger.Debug($"Caculate Gradient: {m.Name}");
var results = g.InvokeMember(m.Name,
BindingFlags.InvokeMethod,
null,
null,
args: new object[] { oper, out_grads }) as Tensor[];
foreach (var result in results.Where(x => x != null))
tf.Logger.Debug($"{result.TensorShape}");
return results;
}
); );
} }




+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -358,7 +358,7 @@ namespace Tensorflow
} }


var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy }); var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy });
return op.output; return op.output;
} }




+ 3
- 2
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -335,9 +335,10 @@ namespace Tensorflow.Keras.Engine


var layer_inputs = node.MapArguments(tensor_dict); var layer_inputs = node.MapArguments(tensor_dict);


tf.Logger.Debug($"{node.Layer}: {node.Layer.Name}");
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}");
var outputs = node.Layer.Apply(layer_inputs, is_training: training); var outputs = node.Layer.Apply(layer_inputs, is_training: training);

foreach (var output in outputs.Where(x => x != null))
tf.Logger.Debug($"{output.TensorShape}");
// Update tensor_dict for next input // Update tensor_dict for next input
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));


+ 0
- 3
src/TensorFlowNET.Keras/KerasApi.cs View File

@@ -4,9 +4,6 @@ namespace Tensorflow
{ {
public static class KerasApi public static class KerasApi
{ {
public static KerasInterface Keras(this tensorflow tf)
=> new KerasInterface();

public static KerasInterface keras { get; } = new KerasInterface(); public static KerasInterface keras { get; } = new KerasInterface();
} }
} }

Loading…
Cancel
Save