Browse Source

add avg_pool_grad function

tags/v0.110.4-Transformer-Model
dogvane 2 years ago
parent
commit
b968fd79ab
1 changed files with 17 additions and 0 deletions
  1. +17
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs

+ 17
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -365,6 +365,23 @@ namespace Tensorflow.Gradients
};
}

[RegisterGradient("AvgPool")]
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads)
{
Tensor grad = grads[0];
return new Tensor[]
{
gen_nn_ops.avg_pool_grad(
array_ops.shape(op.inputs[0]),
grad,
op.get_attr_list<int>("ksize"),
op.get_attr_list<int>("strides"),
op.get_attr("padding").ToString(),
op.get_attr("data_format").ToString())
};
}

/// <summary>
/// Return the gradients for TopK.
/// </summary>


Loading…
Cancel
Save