Browse Source

assert all close

pull/1215/head
Alexander 1 year ago
parent
commit
165e9169e4
2 changed files with 10 additions and 30 deletions
  1. +1
    -21
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  2. +9
    -9
      test/Tensorflow.UnitTest/PythonTest.cs

+ 1
- 21
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -625,25 +625,6 @@ namespace TensorFlowNET.UnitTest.Gradient
}
}

// TODO: remove when np.testing.assert_allclose(a, b) is implemented
private class CollectionComparer : System.Collections.IComparer
{
private readonly double _epsilon = 1e-07;

public int Compare(object x, object y)
{
var a = (double)x;
var b = (double)y;

double delta = Math.Abs(a - b);
if (delta < _epsilon)
{
return 0;
}
return a.CompareTo(b);
}
}

private struct Case
{
public Tensor[] grad1;
@@ -748,8 +729,7 @@ namespace TensorFlowNET.UnitTest.Gradient
var npgrad2 = result[1];
foreach (var (a, b) in npgrad1.Zip(npgrad2))
{
// TODO: np.testing.assert_allclose(a, b);
CollectionAssert.AreEqual(a.ToArray(), b.ToArray(), new CollectionComparer());
self.assertAllClose(a, b);
}
}
}


+ 9
- 9
test/Tensorflow.UnitTest/PythonTest.cs View File

@@ -185,9 +185,9 @@ namespace TensorFlowNET.UnitTest

#region tensor evaluation and test session

private Session _cached_session = null;
private Graph _cached_graph = null;
private object _cached_config = null;
private Session? _cached_session = null;
private Graph? _cached_graph = null;
private object? _cached_config = null;
private bool _cached_force_gpu = false;

private void _ClearCachedSession()
@@ -237,7 +237,7 @@ namespace TensorFlowNET.UnitTest
/// </summary>
public T evaluate<T>(Tensor tensor)
{
object result = null;
object? result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
@@ -274,7 +274,7 @@ namespace TensorFlowNET.UnitTest

///Returns a TensorFlow Session for use in executing tests.
public Session cached_session(
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
Graph? graph = null, object? config = null, bool use_gpu = false, bool force_gpu = false)
{
// This method behaves differently than self.session(): for performance reasons
// `cached_session` will by default reuse the same session within the same
@@ -325,7 +325,7 @@ namespace TensorFlowNET.UnitTest
}

//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
public Session session(Graph? graph = null, object? config = null, bool use_gpu = false, bool force_gpu = false)
{
//Note that this will set this session and the graph as global defaults.

@@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest
// A Session object that should be used as a context manager to surround
// the graph building and execution code in a test case.

Session s = null;
Session? s = null;
//if (context.executing_eagerly())
// yield None
//else
@@ -448,8 +448,8 @@ namespace TensorFlowNET.UnitTest
}

private Session _get_cached_session(
Graph graph = null,
object config = null,
Graph? graph = null,
object? config = null,
bool force_gpu = false,
bool crash_if_inconsistent_args = true)
{


Loading…
Cancel
Save