Browse Source

Add a tolorance to equivalence of NDArray.

pull/1047/head
Yaohui Liu 2 years ago
parent
commit
49a902f6ee
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
3 changed files with 21 additions and 4 deletions
  1. +19
    -2
      src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
  2. +1
    -1
      src/TensorflowNET.Hub/Tensorflow.Hub.csproj
  3. +1
    -1
      test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj

+ 19
- 2
src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs View File

@@ -33,7 +33,16 @@ namespace Tensorflow.NumPy
return Scalar(false);
if(rhs is null)
return Scalar(false);
return new NDArray(math_ops.equal(lhs, rhs));
// TODO(Rinne): use np.allclose instead.
if (lhs.dtype.is_floating() || rhs.dtype.is_floating())
{
var diff = tf.abs(lhs - rhs);
return new NDArray(gen_math_ops.less(diff, new NDArray(1e-5).astype(diff.dtype)));
}
else
{
return new NDArray(math_ops.equal(lhs, rhs));
}
}
[AutoNumPy]
public static NDArray operator !=(NDArray lhs, NDArray rhs)
@@ -42,7 +51,15 @@ namespace Tensorflow.NumPy
return Scalar(false);
if(lhs is null || rhs is null)
return Scalar(true);
return new NDArray(math_ops.not_equal(lhs, rhs));
if (lhs.dtype.is_floating() || rhs.dtype.is_floating())
{
var diff = tf.abs(lhs - rhs);
return new NDArray(gen_math_ops.greater_equal(diff, new NDArray(1e-5).astype(diff.dtype)));
}
else
{
return new NDArray(math_ops.not_equal(lhs, rhs));
}
}
}
}

+ 1
- 1
src/TensorflowNET.Hub/Tensorflow.Hub.csproj View File

@@ -6,7 +6,7 @@
<Nullable>enable</Nullable>
<Version>1.0.0</Version>
<PackageId>TensorFlow.NET.Hub</PackageId>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageLicenseExpression>Apache2.0</PackageLicenseExpression>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>
<Authors>Yaohui Liu, Haiping Chen</Authors>


+ 1
- 1
test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6</TargetFramework>


Loading…
Cancel
Save