Browse Source

Added support for Complex128 and unit tests for it.

tags/v0.100.5-BERT-load
BalashovK 2 years ago
parent
commit
febba7b354
2 changed files with 93 additions and 15 deletions
  1. +19
    -10
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  2. +74
    -5
      test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs

+ 19
- 10
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -730,7 +730,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
{
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(new object[] { input }));
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout }));
}

/// <summary>
@@ -4971,13 +4971,16 @@ namespace Tensorflow.Operations
/// tf.complex(real, imag) ==&amp;gt; [[2.25 + 4.75j], [3.25 + 5.75j]]
/// </code>
/// </remarks>
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex")
{
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(new object[] { real, imag })); // sorry, cannot pass Tout, so it only works with complex64. complex128 is not supported yet
TF_DataType Tin = real.GetDataType();
if (a_Tout is null)
{
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64;
}
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout }));
}



/// <summary>
/// Computes the complex absolute value of a tensor.
/// </summary>
@@ -4999,7 +5002,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
{
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(new object[] { x }));
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout }));
}

/// <summary>
@@ -13308,9 +13311,12 @@ namespace Tensorflow.Operations
/// tf.imag(input) ==&amp;gt; [4.75, 5.75]
/// </code>
/// </remarks>
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag")
{
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
TF_DataType Tin = input.GetDataType();
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));

// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
}

/// <summary>
@@ -23841,9 +23847,12 @@ namespace Tensorflow.Operations
/// tf.real(input) ==&amp;gt; [-2.25, 3.25]
/// </code>
/// </remarks>
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real")
{
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
TF_DataType Tin = input.GetDataType();
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));

// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
}

/// <summary>


+ 74
- 5
test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs View File

@@ -13,7 +13,8 @@ namespace TensorFlowNET.UnitTest.Basics
[TestClass]
public class ComplexTest : EagerModeTestBase
{
[Ignore("Not working")]
// Tests for Complex128

[TestMethod]
public void complex128_basic()
{
@@ -23,7 +24,7 @@ namespace TensorFlowNET.UnitTest.Basics
Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_real_result = tf.math.real(t_complex);
Tensor t_imag_result = tf.math.imag(t_complex);
@@ -34,9 +35,77 @@ namespace TensorFlowNET.UnitTest.Basics
double[] d_real_result =n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.AreEqual(d_real_result, d_real);
Assert.AreEqual(d_imag_result, d_imag);
Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void complex128_abs()
{
tf.enable_eager_execution();

double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };

double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_abs_result = tf.abs(t_complex);

double[] d_abs_result = t_abs_result.numpy().ToArray<double>();
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
}
[TestMethod]
public void complex128_conj()
{
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };

double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);

Tensor t_result = tf.math.conj(t_complex);

NDArray n_real_result = tf.math.real(t_result).numpy();
NDArray n_imag_result = tf.math.imag(t_result).numpy();

double[] d_real_result = n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
}
[TestMethod]
public void complex128_angle()
{
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 };
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 };

double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);

Tensor t_result = tf.math.angle(t_complex);

NDArray n_result = t_result.numpy();

double[] d_result = n_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_result, d_expected));
}

// Tests for Complex64
[TestMethod]
public void complex64_basic()
{
@@ -47,7 +116,7 @@ namespace TensorFlowNET.UnitTest.Basics
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_real_result = tf.math.real(t_complex);
Tensor t_imag_result = tf.math.imag(t_complex);


Loading…
Cancel
Save