diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index c0d0c5a6..c9693f05 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -730,7 +730,7 @@ namespace Tensorflow.Operations /// public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle") { - return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(new object[] { input })); + return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout })); } /// @@ -4971,13 +4971,16 @@ namespace Tensorflow.Operations /// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] /// /// - public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex") + public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex") { - return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(new object[] { real, imag })); // sorry, cannot pass Tout, so it only works with complex64. complex128 is not supported yet + TF_DataType Tin = real.GetDataType(); + if (a_Tout is null) + { + a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64; + } + return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout })); } - - /// /// Computes the complex absolute value of a tensor. /// @@ -4999,7 +5002,7 @@ namespace Tensorflow.Operations /// public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs") { - return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(new object[] { x })); + return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout })); } /// @@ -13308,9 +13311,12 @@ namespace Tensorflow.Operations /// tf.imag(input) ==> [4.75, 5.75] /// /// - public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag") + public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag") { - return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input })); + TF_DataType Tin = input.GetDataType(); + return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); + + // return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input })); } /// @@ -23841,9 +23847,12 @@ namespace Tensorflow.Operations /// tf.real(input) ==> [-2.25, 3.25] /// /// - public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real") + public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real") { - return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input})); + TF_DataType Tin = input.GetDataType(); + return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); + +// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input})); } /// diff --git a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs index c9b05e61..a57ec929 100644 --- a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs @@ -13,7 +13,8 @@ namespace TensorFlowNET.UnitTest.Basics [TestClass] public class ComplexTest : EagerModeTestBase { - [Ignore("Not working")] + // Tests for Complex128 + [TestMethod] public void complex128_basic() { @@ -23,7 +24,7 @@ namespace TensorFlowNET.UnitTest.Basics Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE); Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); - Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); + Tensor t_complex = tf.complex(t_real, t_imag); Tensor t_real_result = tf.math.real(t_complex); Tensor t_imag_result = tf.math.imag(t_complex); @@ -34,9 +35,77 @@ namespace TensorFlowNET.UnitTest.Basics double[] d_real_result =n_real_result.ToArray(); double[] d_imag_result = n_imag_result.ToArray(); - Assert.AreEqual(d_real_result, d_real); - Assert.AreEqual(d_imag_result, d_imag); + Assert.IsTrue(base.Equal(d_real_result, d_real)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag)); + } + [TestMethod] + public void complex128_abs() + { + tf.enable_eager_execution(); + + double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; + double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; + + double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag); + + Tensor t_abs_result = tf.abs(t_complex); + + double[] d_abs_result = t_abs_result.numpy().ToArray(); + Assert.IsTrue(base.Equal(d_abs_result, d_abs)); + } + [TestMethod] + public void complex128_conj() + { + double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; + double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; + + double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 }; + double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); + + Tensor t_result = tf.math.conj(t_complex); + + NDArray n_real_result = tf.math.real(t_result).numpy(); + NDArray n_imag_result = tf.math.imag(t_result).numpy(); + + double[] d_real_result = n_real_result.ToArray(); + double[] d_imag_result = n_imag_result.ToArray(); + + Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); + Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); + } + [TestMethod] + public void complex128_angle() + { + double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 }; + double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 }; + + double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 }; + + Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); + Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); + + Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); + + Tensor t_result = tf.math.angle(t_complex); + + NDArray n_result = t_result.numpy(); + + double[] d_result = n_result.ToArray(); + + Assert.IsTrue(base.Equal(d_result, d_expected)); } + + // Tests for Complex64 [TestMethod] public void complex64_basic() { @@ -47,7 +116,7 @@ namespace TensorFlowNET.UnitTest.Basics Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); - Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); + Tensor t_complex = tf.complex(t_real, t_imag); Tensor t_real_result = tf.math.real(t_complex); Tensor t_imag_result = tf.math.imag(t_complex);