|
|
@@ -14,6 +14,7 @@ |
|
|
|
limitations under the License. |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using System.Linq; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
@@ -36,5 +37,24 @@ namespace Tensorflow |
|
|
|
return t_max; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Computes the global norm of multiple tensors. |
|
|
|
/// </summary> |
|
|
|
/// <param name="t_list"></param> |
|
|
|
/// <param name="name"></param> |
|
|
|
/// <returns></returns> |
|
|
|
public static Tensor global_norm(Tensor[] t_list, string name = null) |
|
|
|
{ |
|
|
|
return tf_with(ops.name_scope(name, "global_norm", t_list), delegate |
|
|
|
{ |
|
|
|
var half_squared_norms = t_list.Select(v => nn_ops.l2_loss(v)).ToArray(); |
|
|
|
var half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms)); |
|
|
|
var norm = math_ops.sqrt(half_squared_norm * |
|
|
|
constant_op.constant(2.0, dtype: half_squared_norm.dtype), |
|
|
|
name: "global_norm"); |
|
|
|
return norm; |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
} |