
unfinished gradients. compile failed.

tags/v0.8.0
Oceania2018 committed 6 years ago
parent commit 31fbe4f026
5 changed files with 109 additions and 13 deletions

  1. +72 -11  src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  2. +20  -0  src/TensorFlowNET.Core/Gradients/math_grad.py.cs
  3.  +6  -0  src/TensorFlowNET.Core/Operations/InputList.cs
  4.  +5  -0  src/TensorFlowNET.Core/Tensors/Tensor.cs
  5.  +6  -2  src/TensorFlowNET.Core/ops.py.cs

+72 -11  src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

@@ -46,6 +46,8 @@ namespace Tensorflow
             all.AddRange(stop_gradients);
             all.AddRange(grad_ys);
 
+            var grads = new Dictionary<string, object>();
+
             Python.with<ops.name_scope>(new ops.name_scope(name, "gradients", values: all), scope =>
             {
                 string grad_scope = scope;
@@ -78,7 +80,7 @@ namespace Tensorflow
             * aggregate the list of received gradients into a Add() Operation if there
             * is more than one.
             **/
-            var grads = new Dictionary<string, Tensor[][]>();
             for (int i = 0; i < ys.Count(); i++)
             {
                 (var y, var grad_y) = Python.zip(ys, grad_ys, i);
@@ -111,6 +113,7 @@ namespace Tensorflow
                     //loop_state.EnterGradWhileContext(op, before: true);
                     var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
 
+                    Tensor[] in_grads = null;
                     var is_partitioned_call = _IsPartitionedCall(op);
                     var is_func_call = false;
                     var has_out_grads = true;
@@ -124,13 +127,60 @@ namespace Tensorflow
                     {
                         // A grad_fn must be defined, either as a function or as None
                         // for ops that do not have gradients.
+                        var grad_fn = ops.get_gradient_function(op);
+
+                        Python.with<ops.name_scope>(new ops.name_scope(op.Name + "_grad"), delegate
+                        {
+                            if (grad_fn != null)
+                            {
+                                in_grads = _MaybeCompile(grad_scope, op, out_grads[0], null, grad_fn);
+                                _VerifyGeneratedGradients(in_grads, op);
+                            }
+                        });
+                    }
+                }
+
+                for (int i = 0; i < in_grads.Length; i++)
+                {
+                    var inputs = (List<Tensor>)_NonEagerInputs(op, xs);
+                    var (t_in, in_grad) = Python.zip(inputs, in_grads, i);
+                    if (in_grad != null)
+                    {
+                        in_grad.shape = t_in.shape;
+                        _SetGrad(grads, t_in, in_grad);
+                    }
                 }
 
+                // Update pending count for the inputs of op and enqueue ready ops.
+                _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
             }
         });
 
-            return null;
+            return xs.Select(x => _GetGrad(grads, x)).ToArray();
         }
 
+        private static void _UpdatePendingAndEnqueueReady(Dictionary<string, Tensor[][]> grads,
+            Operation op,
+            Queue<Operation> queue,
+            Dictionary<string, int> pending_count,
+            object loop_state,
+            Tensor[] xs)
+        {
+
+        }
+
+        private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
+        {
+            if (grads.Count() != op.inputs._inputs.Count())
+                throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
+                    $"inputs {op.inputs._inputs.Count()}");
+        }
+
+        private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, (Tensor, Tensor)> grad_fn)
+        {
+            var in_grads = grad_fn(op, out_grads);
+            return new Tensor[] { in_grads.Item1, in_grads.Item2 };
+        }
 
         private static bool _IsPartitionedCall(Operation op)
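A note on the `Python.zip(a, b, i)` helper used twice in the hunk above: from the call sites it appears to pair up the i-th elements of two sequences, like `list(zip(a, b))[i]` in Python. A minimal sketch of the assumed semantics (the real helper lives elsewhere in the repo):

using System.Collections.Generic;
using System.Linq;

public static class PythonZipSketch
{
    // Assumed behavior of Python.zip(a, b, i) as used above:
    // return the i-th pair of elements, i.e. list(zip(a, b))[i] in Python.
    public static (T1, T2) zip<T1, T2>(IEnumerable<T1> a, IEnumerable<T2> b, int i)
        => (a.ElementAt(i), b.ElementAt(i));
}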
@@ -138,9 +188,9 @@ namespace Tensorflow
             return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
         }
 
-        private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
+        private static Tensor[] _AggregatedGrads(Dictionary<string, object> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
         {
-            var out_grads = _GetGrads(grads, op);
+            var out_grads = _GetGrads(grads, op) as object[];
             for (int i = 0; i < out_grads.Length; i++)
             {
                 var out_grad = out_grads[i];
@@ -195,12 +245,22 @@ namespace Tensorflow
             return stop_ops.ToArray();
         }
 
-        private static Tensor[][] _GetGrads(Dictionary<string, Tensor[][]> grads, Operation op)
+        private static Tensor _GetGrad(Dictionary<string, Tensor[][]> grads, Tensor t)
+        {
+            var op = t.op;
+            if (!grads.ContainsKey(op.Name))
+                return null;
+            Tensor[][] op_grads = grads[op.Name];
+            var t_grad = op_grads[t.value_index];
+            return t_grad[0];
+        }
+
+        private static object _GetGrads(Dictionary<string, object> grads, Operation op)
         {
             if (grads.ContainsKey(op.Name))
                 return grads[op.Name];
             else
-                return op.outputs.Select(x => new Tensor[0]).ToArray();
+                return op.outputs.Select(x => new object[0]).ToArray();
         }
 
         /// <summary>
@@ -209,17 +269,17 @@ namespace Tensorflow
         /// <param name="grads"></param>
         /// <param name="t"></param>
         /// <param name="grad"></param>
-        private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
+        private static void _SetGrad(Dictionary<string, object> grads, Tensor t, Tensor grad)
         {
             var op = t.op;
-            Tensor[][] op_grads = null;
+            object op_grads = null;
             if (!grads.ContainsKey(op.Name))
             {
-                op_grads = op.outputs.Select(x => new Tensor[1]).ToArray();
+                op_grads = op.outputs.Select(x => new object[1]).ToList();
                 grads[op.Name] = op_grads;
             }
-            var t_grads = op_grads[t.value_index];
-            t_grads[0] = grad;
+            var t_grads = (op_grads as object[])[t.value_index];
+            // t_grads[0] = grad;
         }
 
         /// <summary>
@@ -322,6 +382,7 @@ namespace Tensorflow
         {
             return op.inputs;
         }
+
         /// <summary>
         /// Mark all ops reached from "from_ops"
         /// </summary>
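The `_SetGrad`/`_GetGrad` pair is this commit's per-tensor gradient bookkeeping: gradients are stored per producing op, keyed by op name, with one slot per output index, and summed later by `_AggregatedGrads`. The committed version is caught mid-refactor between `Tensor[][]` and `object` (hence "compile failed"), so here is a sketch of the intended shape, assuming the `Tensor[][]` representation:

// Sketch only, assuming the Tensor[][] layout:
// grads[op.Name][output_index] holds the gradient(s) received for that
// output of the op; slot 0 is the (so far single) stored gradient.
private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
{
    if (!grads.ContainsKey(t.op.Name))
        grads[t.op.Name] = t.op.outputs.Select(_ => new Tensor[1]).ToArray();
    grads[t.op.Name][t.value_index][0] = grad;
}

private static Tensor _GetGrad(Dictionary<string, Tensor[][]> grads, Tensor t)
{
    return grads.ContainsKey(t.op.Name)
        ? grads[t.op.Name][t.value_index][0]
        : null;
}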


+20 -0  src/TensorFlowNET.Core/Gradients/math_grad.py.cs

@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Gradients for operators defined in math_ops.py.
+    /// </summary>
+    public class math_grad
+    {
+        public static (Tensor, Tensor) _AddGrad(Operation op, Tensor grad)
+        {
+            var x = op.inputs[0];
+            var y = op.inputs[1];
+
+            return (grad, grad);
+        }
+    }
+}
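For z = x + y the partial derivatives are dz/dx = dz/dy = 1, so the upstream gradient flows through unchanged to both inputs; that is all `_AddGrad` needs to do here (the shape-restoring reductions for broadcast inputs that upstream TensorFlow performs are not handled yet). A hypothetical call site, with `add_op` and `upstream` assumed:

// Hypothetical usage: add_op is an "Add" Operation, upstream is the
// gradient of the objective w.r.t. add_op's output.
var (grad_x, grad_y) = math_grad._AddGrad(add_op, upstream);
// Both grad_x and grad_y are the unchanged upstream tensor,
// because d(x + y)/dx = d(x + y)/dy = 1.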

+6 -0  src/TensorFlowNET.Core/Operations/InputList.cs

@@ -1,6 +1,7 @@
 using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 
 namespace Tensorflow
@@ -19,5 +20,10 @@ namespace Tensorflow
{ {
return _inputs.GetEnumerator(); return _inputs.GetEnumerator();
} }

public static implicit operator List<Tensor>(InputList input)
{
return input._inputs.ToList();
}
} }
} }
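This implicit conversion is what lets the `(List<Tensor>)_NonEagerInputs(op, xs)` cast in gradients_impl.py.cs compile: an `InputList` now converts to `List<Tensor>` on demand. A usage sketch (`op` is any `Operation`):

// InputList converts implicitly wherever a List<Tensor> is expected:
List<Tensor> inputs = op.inputs;   // InputList -> List<Tensor>
Tensor first = inputs.Count > 0 ? inputs[0] : null;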

+5 -0  src/TensorFlowNET.Core/Tensors/Tensor.cs

@@ -57,6 +57,11 @@ namespace Tensorflow
 
                 return dims;
             }
+
+            set
+            {
+                // c_api.TF_GraphSetTensorShape_wrapper
+            }
         }
 
         /// <summary>
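The setter is stubbed, with a comment pointing at a future `c_api.TF_GraphSetTensorShape_wrapper` binding. TensorFlow's C API exposes `TF_GraphSetTensorShape(graph, output, dims, num_dims, status)` for this, so the wired-up setter will likely look roughly like the following sketch (the binding and the `graph`/`_as_tf_output()`/`status` members are assumptions, not committed code):

set
{
    // Sketch only. Assumes a P/Invoke binding for TF_GraphSetTensorShape
    // plus helper members on Tensor; names are hypothetical.
    var dims = value.Select(x => (long)x).ToArray();
    c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), dims, dims.Length, status);
}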


+6 -2  src/TensorFlowNET.Core/ops.py.cs

@@ -279,10 +279,14 @@ namespace Tensorflow
             return tf.Session();
         }
 
-        public static object get_gradient_function(Operation op)
+        public static Func<Operation, Tensor, (Tensor, Tensor)> get_gradient_function(Operation op)
         {
             if (op.inputs == null) return null;
-            return null;
+
+            return (oper, out_grads) =>
+            {
+                return math_grad._AddGrad(op, out_grads);
+            };
         }
     }
 }
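As committed, `get_gradient_function` returns a lambda that applies `math_grad._AddGrad` to every op (note it also closes over the outer `op` and ignores its own `oper` parameter), which only works while "Add" is the sole op being differentiated. The natural end state, in the spirit of TensorFlow's gradient registry, is dispatch on `op.OpType`; a sketch, with the registry name an assumption:

// Sketch only: dispatch gradient functions by op type
// instead of hard-wiring _AddGrad for every op.
private static readonly Dictionary<string, Func<Operation, Tensor, (Tensor, Tensor)>> _gradRegistry =
    new Dictionary<string, Func<Operation, Tensor, (Tensor, Tensor)>>
    {
        ["Add"] = math_grad._AddGrad
    };

public static Func<Operation, Tensor, (Tensor, Tensor)> get_gradient_function(Operation op)
{
    if (op.inputs == null) return null;
    return _gradRegistry.TryGetValue(op.OpType, out var fn) ? fn : null;
}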
