diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 6a186a63..1c28f124 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -29,6 +29,10 @@ namespace Tensorflow /// A convenient alias for None, useful for indexing arrays. /// public Slice newaxis = Slice.NewAxis; + /// + /// A convenient alias for ... + /// + public Slice ellipsis = Slice.Ellipsis; /// /// BatchToSpace for N-D tensors of type T. diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs index 92013c13..9d836560 100644 --- a/src/TensorFlowNET.Core/APIs/tf.image.cs +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -209,7 +209,7 @@ namespace Tensorflow name, sorted_input, canonicalized_coordinates, tile_size); public Tensor resize(Tensor image, TensorShape size) - => image_ops_impl.resize_images(image, tf.constant(size)); + => image_ops_impl.resize_images_v2(image, size); public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null) => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, half_pixel_centers: half_pixel_centers, name: name); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index cf66ca48..084a8cb3 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -19,6 +19,12 @@ namespace Tensorflow.Eager public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); + public override void set_shape(TensorShape shape) + { + if (!shape.is_compatible_with(this.shape)) + throw new ValueError($"Tensor's shape is not compatible."); + } + public static int GetRank(IntPtr handle) { var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index 86434bfc..947a7b2e 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -672,10 +672,8 @@ or rank = 4. Had rank = {0}", rank)); internal static Tensor _resize_images_common(Tensor images, Func resizer_fn, Tensor size, bool preserve_aspect_ratio, string name, bool skip_resize_if_same) { - using (ops.name_scope(name, "resize", new [] {images, size})) - return tf_with(ops.name_scope(name, "resize", new [] {images, size}), delegate + return tf_with(ops.name_scope(name, "resize", new[] {images, size}), delegate { - images = ops.convert_to_tensor(images, name: "images"); if (images.TensorShape.ndim == Unknown) throw new ValueError("\'images\' contains no shape."); bool is_batch = true; @@ -688,18 +686,6 @@ or rank = 4. Had rank = {0}", rank)); var (height, width) = (images.dims[1], images.dims[2]); - try - { - size = ops.convert_to_tensor(size, dtypes.int32, name: "size"); - } - catch (Exception ex) - { - if (ex is TypeError || ex is ValueError) - throw new ValueError("\'size\' must be a 1-D int32 Tensor"); - else - throw; - } - if (!size.TensorShape.is_compatible_with(new [] {2})) throw new ValueError(@"\'size\' must be a 1-D Tensor of 2 elements: new_height, new_width"); @@ -756,7 +742,7 @@ new_height, new_width"); images = resizer_fn(images, size); - // images.set_shape(new TensorShape(new int[] { -1, new_height_const, new_width_const, -1 })); + images.set_shape(new TensorShape(new int[] { Unknown, new_height_const, new_width_const, Unknown })); if (!is_batch) images = array_ops.squeeze(images, axis: new int[] {0}); @@ -2163,6 +2149,33 @@ new_height, new_width"); }); } + /// + /// Resize `images` to `size` using the specified `method`. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor resize_images_v2(Tensor images, TensorShape size, string method = ResizeMethod.BILINEAR, + bool preserve_aspect_ratio = false, + bool antialias = false, + string name = null) + { + Func resize_fn = (images, size) => + { + if (method == ResizeMethod.BILINEAR) + return gen_image_ops.resize_bilinear(images, size, half_pixel_centers: true); + throw new NotImplementedException(""); + }; + return _resize_images_common(images, resize_fn, ops.convert_to_tensor(size), + preserve_aspect_ratio: preserve_aspect_ratio, + skip_resize_if_same: false, + name: name); + } + /// /// Resize `images` to `size` using nearest neighbor interpolation. /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index b1b6700d..926524e9 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -149,7 +149,7 @@ namespace Tensorflow /// /// Updates the shape of this tensor. /// - public void set_shape(TensorShape shape) + public virtual void set_shape(TensorShape shape) { this.shape = shape.rank >= 0 ? shape.dims : null; } diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index 47ec6a7c..f056084d 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -4,6 +4,7 @@ using NumSharp; using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Reflection; using System.Text; using Tensorflow; @@ -35,9 +36,10 @@ namespace TensorFlowNET.UnitTest.Basics Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); } - [TestMethod, Ignore] + [TestMethod] public void resize_image() { + tf.enable_eager_execution(); var image = tf.constant(new int[5, 5] { {1, 0, 0, 0, 0 }, @@ -46,10 +48,16 @@ namespace TensorFlowNET.UnitTest.Basics {0, 0, 0, 1, 0 }, {0, 0, 0, 0, 1 } }); - //image = image[tf.newaxis, ..., tf.newaxis]; - - var img = tf.image.resize(contents, (3, 5)); - Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); + image = image[tf.newaxis, tf.ellipsis, tf.newaxis]; + image = tf.image.resize(image, (3, 5)); + image = image[0, tf.ellipsis, 0]; + Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 0.6666667f, 0.3333333f, 0, 0, 0 }, + image[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 0, 0, 1, 0, 0 }, + image[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new float[] { 0, 0, 0, 0.3333335f, 0.6666665f }, + image[2].ToArray())); + tf.compat.v1.disable_eager_execution(); } [TestMethod]