Browse Source

unfinished gradients. compile failed.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
31fbe4f026
5 changed files with 109 additions and 13 deletions
  1. +72
    -11
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  2. +20
    -0
      src/TensorFlowNET.Core/Gradients/math_grad.py.cs
  3. +6
    -0
      src/TensorFlowNET.Core/Operations/InputList.cs
  4. +5
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  5. +6
    -2
      src/TensorFlowNET.Core/ops.py.cs

+ 72
- 11
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -46,6 +46,8 @@ namespace Tensorflow
all.AddRange(stop_gradients);
all.AddRange(grad_ys);

var grads = new Dictionary<string, object>();

Python.with<ops.name_scope>(new ops.name_scope(name, "gradients", values: all), scope =>
{
string grad_scope = scope;
@@ -78,7 +80,7 @@ namespace Tensorflow
* aggregate the list of received gradients into a Add() Operation if there
* is more than one.
**/
var grads = new Dictionary<string, Tensor[][]>();
for(int i = 0; i < ys.Count(); i++)
{
(var y, var grad_y) = Python.zip(ys, grad_ys, i);
@@ -111,6 +113,7 @@ namespace Tensorflow
//loop_state.EnterGradWhileContext(op, before: true);
var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);

Tensor[] in_grads = null;
var is_partitioned_call = _IsPartitionedCall(op);
var is_func_call = false;
var has_out_grads = true;
@@ -124,13 +127,60 @@ namespace Tensorflow
{
// A grad_fn must be defined, either as a function or as None
// for ops that do not have gradients.
var grad_fn = ops.get_gradient_function(op);

Python.with<ops.name_scope>(new ops.name_scope(op.Name + "_grad"), delegate
{
if (grad_fn != null)
{
in_grads = _MaybeCompile(grad_scope, op, out_grads[0], null, grad_fn);
_VerifyGeneratedGradients(in_grads, op);
}
});
}
}

for(int i =0; i< in_grads.Length; i++)
{
var inputs = (List<Tensor>)_NonEagerInputs(op, xs);
var (t_in, in_grad) = Python.zip(inputs, in_grads, i);
if(in_grad != null)
{
in_grad.shape = t_in.shape;
_SetGrad(grads, t_in, in_grad);
}
}

// Update pending count for the inputs of op and enqueue ready ops.
_UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
}
});

return null;
return xs.Select(x => _GetGrad(grads, x)).ToArray();
}

private static void _UpdatePendingAndEnqueueReady(Dictionary<string, Tensor[][]> grads,
Operation op,
Queue<Operation> queue,
Dictionary<string ,int> pending_count,
object loop_state,
Tensor[] xs)
{

}

private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
{
if (grads.Count() != op.inputs._inputs.Count())
throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
$"inputs {op.inputs._inputs.Count()}");
}

private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, (Tensor, Tensor)> grad_fn)
{
var in_grads = grad_fn(op, out_grads);
return new Tensor[] { in_grads.Item1, in_grads.Item2 };
}

private static bool _IsPartitionedCall(Operation op)
@@ -138,9 +188,9 @@ namespace Tensorflow
return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
}

private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
private static Tensor[] _AggregatedGrads(Dictionary<string, object> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
{
var out_grads = _GetGrads(grads, op);
var out_grads = _GetGrads(grads, op) as object[];
for(int i = 0; i < out_grads.Length; i++)
{
var out_grad = out_grads[i];
@@ -195,12 +245,22 @@ namespace Tensorflow
return stop_ops.ToArray();
}

private static Tensor[][] _GetGrads(Dictionary<string, Tensor[][]> grads, Operation op)
private static Tensor _GetGrad(Dictionary<string, Tensor[][]> grads, Tensor t)
{
var op = t.op;
if (!grads.ContainsKey(op.Name))
return null;
Tensor[][] op_grads = grads[op.Name];
var t_grad = op_grads[t.value_index];
return t_grad[0];
}

private static object _GetGrads(Dictionary<string, object> grads, Operation op)
{
if (grads.ContainsKey(op.Name))
return grads[op.Name];
else
return op.outputs.Select(x => new Tensor[0]).ToArray();
return op.outputs.Select(x => new object[0]).ToArray();
}

/// <summary>
@@ -209,17 +269,17 @@ namespace Tensorflow
/// <param name="grads"></param>
/// <param name="t"></param>
/// <param name="grad"></param>
private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
private static void _SetGrad(Dictionary<string, object> grads, Tensor t, Tensor grad)
{
var op = t.op;
Tensor[][] op_grads = null;
object op_grads = null;
if (!grads.ContainsKey(op.Name))
{
op_grads = op.outputs.Select(x => new Tensor[1]).ToArray();
op_grads = op.outputs.Select(x => new object[1]).ToList();
grads[op.Name] = op_grads;
}
var t_grads = op_grads[t.value_index];
t_grads[0] = grad;
var t_grads = (op_grads as object[])[t.value_index];
// t_grads[0] = grad;
}

/// <summary>
@@ -322,6 +382,7 @@ namespace Tensorflow
{
return op.inputs;
}

/// <summary>
/// Mark all ops reached from "from_ops"
/// </summary>


+ 20
- 0
src/TensorFlowNET.Core/Gradients/math_grad.py.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Gradients for operators defined in math_ops.py.
/// </summary>
public class math_grad
{
public static (Tensor, Tensor) _AddGrad(Operation op, Tensor grad)
{
var x = op.inputs[0];
var y = op.inputs[1];

return (grad, grad);
}
}
}

+ 6
- 0
src/TensorFlowNET.Core/Operations/InputList.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
@@ -19,5 +20,10 @@ namespace Tensorflow
{
return _inputs.GetEnumerator();
}

public static implicit operator List<Tensor>(InputList input)
{
return input._inputs.ToList();
}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -57,6 +57,11 @@ namespace Tensorflow

return dims;
}

set
{
// c_api.TF_GraphSetTensorShape_wrapper
}
}
/// <summary>


+ 6
- 2
src/TensorFlowNET.Core/ops.py.cs View File

@@ -279,10 +279,14 @@ namespace Tensorflow
return tf.Session();
}

public static object get_gradient_function(Operation op)
public static Func<Operation, Tensor, (Tensor, Tensor)> get_gradient_function(Operation op)
{
if (op.inputs == null) return null;
return null;

return (oper, out_grads) =>
{
return math_grad._AddGrad(op, out_grads);
};
}
}
}

Loading…
Cancel
Save