Browse Source

fix test

pull/1215/head
Alexander 1 year ago
parent
commit
7968dc360f
1 changed files with 13 additions and 3 deletions
  1. +13
    -3
      test/Tensorflow.UnitTest/PythonTest.cs

+ 13
- 3
test/Tensorflow.UnitTest/PythonTest.cs View File

@@ -133,13 +133,23 @@ namespace TensorFlowNET.UnitTest


public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5) public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5)
{ {
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
CollectionAssert.AreEqual(array1.ToArray(), array2.ToArray(), new CollectionComparer(eps));

//TODO: Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
} }


public void assertAllClose(double value, NDArray array2, double eps = 1e-5) public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
{ {
if (array2.shape.IsScalar)
{
double value2 = array2;
Assert.AreEqual(value, value2, eps);
return;
}
var array1 = np.ones_like(array2) * value; var array1 = np.ones_like(array2) * value;
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
CollectionAssert.AreEqual(array1.ToArray(), array2.ToArray(), new CollectionComparer(eps));

//TODO: Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
} }


private class CollectionComparer : IComparer private class CollectionComparer : IComparer
@@ -158,7 +168,7 @@ namespace TensorFlowNET.UnitTest
} }
else if (x == null) else if (x == null)
{ {
return -1;
return -1;
} }
else if (y == null) else if (y == null)
{ {


Loading…
Cancel
Save