Browse Source

修正图片左右和上下翻转的问题,并增加对应测试用例。

tags/v0.150.0-BERT-Model
dogvane 2 years ago
parent
commit
5e4f53077f
9 changed files with 525 additions and 12 deletions
  1. BIN
      data/img001.bmp
  2. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.image.cs
  3. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.io.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  5. +32
    -11
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  6. +22
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  7. +90
    -0
      test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
  8. +317
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
  9. +44
    -0
      test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs

BIN
data/img001.bmp View File

Before After
Width: 244  |  Height: 244  |  Size: 179 kB

+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.image.cs View File

@@ -339,6 +339,13 @@ namespace Tensorflow
=> image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
name: name, expand_animations: expand_animations); name: name, expand_animations: expand_animations);


public Tensor encode_png(Tensor contents, string name = null)
=> image_ops_impl.encode_png(contents, name: name);

public Tensor encode_jpeg(Tensor contents, string name = null)
=> image_ops_impl.encode_jpeg(contents, name: name);


/// <summary> /// <summary>
/// Convenience function to check if the 'contents' encodes a JPEG image. /// Convenience function to check if the 'contents' encodes a JPEG image.
/// </summary> /// </summary>


+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -16,6 +16,7 @@


using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.IO; using Tensorflow.IO;
using Tensorflow.Operations;


namespace Tensorflow namespace Tensorflow
{ {
@@ -46,6 +47,12 @@ namespace Tensorflow
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null) string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);

public Operation write_file(string filename, Tensor conentes, string name = null)
=> write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name);

public Operation write_file(Tensor filename, Tensor conentes, string name = null)
=> gen_ops.write_file(filename, conentes, name);
} }


public GFile gfile = new GFile(); public GFile gfile = new GFile();


+ 6
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -55,6 +55,12 @@ namespace Tensorflow.Keras.Layers
string kernel_initializer = "glorot_uniform", string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros"); string bias_initializer = "zeros");


public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid"
);

public ILayer Conv2D(int filters, public ILayer Conv2D(int filters,
Shape kernel_size = null, Shape kernel_size = null,
Shape strides = null, Shape strides = null,


+ 32
- 11
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -102,7 +102,10 @@ namespace Tensorflow
{ {
throw new ValueError("\'image\' must be fully defined."); throw new ValueError("\'image\' must be fully defined.");
} }
var dims = image_shape["-3:"];
var dims = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 1]});
foreach (var dim in dims.dims) foreach (var dim in dims.dims)
{ {
if (dim == 0) if (dim == 0)
@@ -112,16 +115,18 @@ namespace Tensorflow
} }


var image_shape_last_three_elements = new Shape(new[] { var image_shape_last_three_elements = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 1],
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2], image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 3]});
image_shape.dims[image_shape.dims.Length - 1]});
if (!image_shape_last_three_elements.IsFullyDefined) if (!image_shape_last_three_elements.IsFullyDefined)
{ {
Tensor image_shape_ = array_ops.shape(image); Tensor image_shape_ = array_ops.shape(image);
var image_shape_return = tf.constant(new[] {
image_shape_.dims[image_shape.dims.Length - 1],
image_shape_.dims[image_shape.dims.Length - 2],
image_shape_.dims[image_shape.dims.Length - 3]});
var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 });

//var image_shape_return = tf.constant(new[] {
// image_shape_.dims[image_shape_.dims.Length - 3],
// image_shape_.dims[image_shape_.dims.Length - 2],
// image_shape_.dims[image_shape_.dims.Length - 1]});


return new Operation[] { return new Operation[] {
check_ops.assert_positive( check_ops.assert_positive(
@@ -209,10 +214,10 @@ namespace Tensorflow
} }


public static Tensor flip_left_right(Tensor image) public static Tensor flip_left_right(Tensor image)
=> _flip(image, 0, "flip_left_right");
=> _flip(image, 1, "flip_left_right");


public static Tensor flip_up_down(Tensor image) public static Tensor flip_up_down(Tensor image)
=> _flip(image, 1, "flip_up_down");
=> _flip(image, 0, "flip_up_down");


internal static Tensor _flip(Tensor image, int flip_index, string scope_name) internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
{ {
@@ -223,11 +228,11 @@ namespace Tensorflow
Shape shape = image.shape; Shape shape = image.shape;
if (shape.ndim == 3 || shape.ndim == Unknown) if (shape.ndim == 3 || shape.ndim == Unknown)
{ {
return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })));
return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index })));
} }
else if (shape.ndim == 4) else if (shape.ndim == 4)
{ {
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 }));
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 }));
} }
else else
{ {
@@ -2047,6 +2052,22 @@ new_height, new_width");
}); });
} }


public static Tensor encode_jpeg(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "encode_jpeg"), scope =>
{
return gen_ops.encode_jpeg(contents, name:name);
});
}

public static Tensor encode_png(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "encode_png"), scope =>
{
return gen_ops.encode_png(contents, name: name);
});
}

public static Tensor is_jpeg(Tensor contents, string name = null) public static Tensor is_jpeg(Tensor contents, string name = null)
{ {
return tf_with(ops.name_scope(name, "is_jpeg"), scope => return tf_with(ops.name_scope(name, "is_jpeg"), scope =>


+ 22
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -112,7 +112,28 @@ namespace Tensorflow.Keras.Layers
KernelInitializer = GetInitializerByName(kernel_initializer), KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer) BiasInitializer = GetInitializerByName(bias_initializer)
}); });

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid")
=> new Conv2D(new Conv2DArgs
{
Rank = 2,
Filters = filters,
KernelSize = (kernel_size == null) ? (5, 5) : kernel_size,
Strides = strides == null ? (1, 1) : strides,
Padding = padding,
DataFormat = null,
DilationRate = (1, 1),
Groups = 1,
UseBias = false,
KernelRegularizer = null,
KernelInitializer =tf.glorot_uniform_initializer,
BiasInitializer = tf.zeros_initializer,
BiasRegularizer = null,
ActivityRegularizer = null,
Activation = keras.activations.Linear,
});
/// <summary> /// <summary>
/// 2D convolution layer (e.g. spatial convolution over images). /// 2D convolution layer (e.g. spatial convolution over images).
/// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs.


+ 90
- 0
test/TensorFlowNET.Graph.UnitTest/ImageTest.cs View File

@@ -4,6 +4,7 @@ using System.Linq;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System; using System;
using System.IO;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
@@ -164,5 +165,94 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(result.size, 16ul); Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f); Assert.AreEqual(result[0, 0, 0, 0], 12f);
} }

[TestMethod]
public void ImageSaveTest()
{
var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");
var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg");
var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png");

File.Delete(jpegImgPath);
File.Delete(pngImgPath);

var contents = tf.io.read_file(imgPath);
var bmp = tf.image.decode_image(contents);
Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0");

var jpeg = tf.image.encode_jpeg(bmp);
var op1 = tf.io.write_file(jpegImgPath, jpeg);

var png = tf.image.encode_png(bmp);
var op2 = tf.io.write_file(pngImgPath, png);

this.session().run(op1);
this.session().run(op2);

Assert.IsTrue(File.Exists(jpegImgPath), "not find file:" + jpegImgPath);
Assert.IsTrue(File.Exists(pngImgPath), "not find file:" + pngImgPath);

// 如果要测试图片正确性,需要注释下面两行代码
File.Delete(jpegImgPath);
File.Delete(pngImgPath);
}

[TestMethod]
public void ImageFlipTest()
{
var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");

var contents = tf.io.read_file(imgPath);
var bmp = tf.image.decode_image(contents);

// 左右翻转
var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png");
File.Delete(lrImgPath);

var lr = tf.image.flip_left_right(bmp);
var png = tf.image.encode_png(lr);
var op = tf.io.write_file(lrImgPath, png);
this.session().run(op);

Assert.IsTrue(File.Exists(lrImgPath), "not find file:" + lrImgPath);

// 上下翻转
var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png");
File.Delete(updownImgPath);

var updown = tf.image.flip_up_down(bmp);
var pngupdown = tf.image.encode_png(updown);
var op2 = tf.io.write_file(updownImgPath, pngupdown);
this.session().run(op2);
Assert.IsTrue(File.Exists(updownImgPath));


// 暂时先人工观测图片是否翻转,观测时需要删除下面这两行代码
File.Delete(lrImgPath);
File.Delete(updownImgPath);

// 多图翻转
// 目前直接通过 bmp 拿到 shape ,这里先用默认定义图片大小来构建了
var mImg = tf.stack(new[] { bmp, lr }, axis:0);
print(mImg.shape);

var up2 = tf.image.flip_up_down(mImg);

var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // 直接上下翻转
File.Delete(updownImgPath_m1);

var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // 先左右再上下
File.Delete(img001_updown_m2);

var png2 = tf.image.encode_png(up2[0]);
tf.io.write_file(updownImgPath_m1, png2);

png2 = tf.image.encode_png(up2[1]);
tf.io.write_file(img001_updown_m2, png2);

// 如果要测试图片正确性,需要注释下面两行代码
File.Delete(updownImgPath_m1);
File.Delete(img001_updown_m2);
}
} }
} }

+ 317
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs View File

@@ -3,6 +3,7 @@ using Tensorflow.NumPy;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Linq; using System.Linq;
using Tensorflow.Operations;


namespace TensorFlowNET.UnitTest.ManagedAPI namespace TensorFlowNET.UnitTest.ManagedAPI
{ {
@@ -105,5 +106,321 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>())); Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>()));
Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>())); Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>()));
} }

[TestMethod]
public void ReverseImgArray3D()
{
// 创建 sourceImg 数组
var sourceImgArray = new float[,,] {
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
};
var sourceImg = ops.convert_to_tensor(sourceImgArray);

// 创建 lrImg 数组
var lrImgArray = new float[,,] {
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
};
var lrImg = ops.convert_to_tensor(lrImgArray);

var lr = tf.image.flip_left_right(sourceImg);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");

var lr2 = tf.reverse(sourceImg, 1);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

// 创建 udImg 数组
var udImgArray = new float[,,] {
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
};
var udImg = ops.convert_to_tensor(udImgArray);

var ud = tf.image.flip_up_down(sourceImg);
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");

var ud2 = tf.reverse(sourceImg, new Axis(0));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail.");

var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail.");
}

[TestMethod]
public void ReverseImgArray4D()
{
// 原图左上角,加一张左右翻转后的图片
var m = new float[,,,] {
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var sourceImg = ops.convert_to_tensor(m);

var lrArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var lrImg = ops.convert_to_tensor(lrArray);

// 创建 ud 数组
var udArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
}
}
};
var udImg = ops.convert_to_tensor(udArray);

var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

var ud2 = tf.reverse(sourceImg, new Axis(1));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var ud = tf.image.flip_up_down(sourceImg);
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");

// 左右翻转
var lr = tf.image.flip_left_right(sourceImg);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");

var lr2 = tf.reverse(sourceImg, 0);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

}

[TestMethod]
public void ReverseImgArray4D_3x3()
{
// 原图左上角,加一张左右翻转后的图片
var m = new float[,,,] {
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var sourceImg = ops.convert_to_tensor(m);

var lrArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var lrImg = ops.convert_to_tensor(lrArray);

// 创建 ud 数组
var udArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{ {
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
}
}
};
var udImg = ops.convert_to_tensor(udArray);

var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

var ud2 = tf.reverse(sourceImg, new Axis(1));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var ud = tf.image.flip_up_down(sourceImg);
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");

// 左右翻转
var lr = tf.image.flip_left_right(sourceImg);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");

var lr2 = tf.reverse(sourceImg, 0);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

}
} }
} }

+ 44
- 0
test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs View File

@@ -0,0 +1,44 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Linq;
using static Tensorflow.Binding;
using Tensorflow;

namespace TensorFlowNET.UnitTest.NumPy
{
[TestClass]
public class ShapeTest : EagerModeTestBase
{
[Ignore]
[TestMethod]
public unsafe void ShapeGetLastElements()
{
// test code from function _CheckAtLeast3DImage
// 之前的 _CheckAtLeast3DImage 有bug,现在通过测试,下面的代码是正确的
// todo: shape["-3:"] 的写法,目前有bug,需要修复,单元测试等修复后再放开,暂时先忽略测试

var image_shape = new Shape(new[] { 32, 64, 3 });
var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 });

var image_shape_last_three_elements = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 1]});

var image_shape_last_three_elements2 = image_shape["-3:"];

Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail.");

var image_shape_last_three_elements_4d = new Shape(new[] {
image_shape_4d.dims[image_shape_4d.dims.Length - 3],
image_shape_4d.dims[image_shape_4d.dims.Length - 2],
image_shape_4d.dims[image_shape_4d.dims.Length - 1]});

var image_shape_last_three_elements2_4d = image_shape_4d["-3:"];

Assert.IsTrue(Equals(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail.");
}

}
}

Loading…
Cancel
Save