Browse Source

add tf.linalg.global_norm #857

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
a4a99bb08f
5 changed files with 40 additions and 1 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.linalg.cs
  2. +20
    -0
      src/TensorFlowNET.Core/Operations/clip_ops.cs
  3. +8
    -0
      src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
  4. +0
    -1
      src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
  5. +9
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.linalg.cs View File

@@ -51,6 +51,9 @@ namespace Tensorflow
public Tensor inv(Tensor input, bool adjoint = false, string name = null)
=> ops.matrix_inverse(input, adjoint: adjoint, name: name);

public Tensor global_norm(Tensor[] t_list, string name = null)
=> clip_ops.global_norm(t_list, name: name);

public Tensor lstsq(Tensor matrix, Tensor rhs,
NDArray l2_regularizer = null, bool fast = true, string name = null)
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);


+ 20
- 0
src/TensorFlowNET.Core/Operations/clip_ops.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -36,5 +37,24 @@ namespace Tensorflow
return t_max;
});
}

/// <summary>
/// Computes the global norm of multiple tensors.
/// </summary>
/// <param name="t_list"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor global_norm(Tensor[] t_list, string name = null)
{
return tf_with(ops.name_scope(name, "global_norm", t_list), delegate
{
var half_squared_norms = t_list.Select(v => nn_ops.l2_loss(v)).ToArray();
var half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms));
var norm = math_ops.sqrt(half_squared_norm *
constant_op.constant(2.0, dtype: half_squared_norm.dtype),
name: "global_norm");
return norm;
});
}
}
}

+ 8
- 0
src/TensorFlowNET.Keras/Engine/MetricsContainer.cs View File

@@ -75,11 +75,19 @@ namespace Tensorflow.Keras.Engine
metric_obj = keras.metrics.sparse_categorical_accuracy;
else
metric_obj = keras.metrics.categorical_accuracy;

metric = "accuracy";
}
else if(metric == "mean_absolute_error" || metric == "mae")
{
metric_obj = keras.metrics.mean_absolute_error;
metric = "mean_absolute_error";
}
else if (metric == "mean_absolute_percentage_error" || metric == "mape")
{
metric_obj = keras.metrics.mean_absolute_percentage_error;
metric = "mean_absolute_percentage_error";
}
else
throw new NotImplementedException("");



+ 0
- 1
src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs View File

@@ -4,7 +4,6 @@ namespace Tensorflow.Keras.Metrics
{
public class MeanMetricWrapper : Mean
{
string name;
Func<Tensor, Tensor, Tensor> _fn = null;

public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT)


+ 9
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.ManagedAPI
@@ -54,5 +55,13 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var e = tf.linalg.einsum("ij,jk->ik", (m0, m1));
Assert.AreEqual(e.shape, (2, 5));
}

[TestMethod]
public void GlobalNorm()
{
var t_list = new Tensors(tf.constant(new float[] { 1, 2, 3, 4 }), tf.constant(new float[] { 5, 6, 7, 8 }));
var norm = tf.linalg.global_norm(t_list);
Assert.AreEqual(norm.numpy(), 14.282857f);
}
}
}

Loading…
Cancel
Save