diff --git a/data/img001.bmp b/data/img001.bmp new file mode 100644 index 00000000..86cfa972 Binary files /dev/null and b/data/img001.bmp differ diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs index ac9cbc60..41ef5296 100644 --- a/src/TensorFlowNET.Core/APIs/tf.image.cs +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -339,6 +339,13 @@ namespace Tensorflow => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, name: name, expand_animations: expand_animations); + public Tensor encode_png(Tensor contents, string name = null) + => image_ops_impl.encode_png(contents, name: name); + + public Tensor encode_jpeg(Tensor contents, string name = null) + => image_ops_impl.encode_jpeg(contents, name: name); + + /// /// Convenience function to check if the 'contents' encodes a JPEG image. /// diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs index be1e86e6..4cb18070 100644 --- a/src/TensorFlowNET.Core/APIs/tf.io.cs +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using Tensorflow.IO; +using Tensorflow.Operations; namespace Tensorflow { @@ -46,6 +47,12 @@ namespace Tensorflow public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); + + public void write_file(string filename, Tensor conentes, string name = null) + => write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name); + + public void write_file(Tensor filename, Tensor conentes, string name = null) + => gen_ops.write_file(filename, conentes, name); } public GFile gfile = new GFile(); diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 5fa4c97d..ed756740 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -39083,13 +39083,14 @@ namespace Tensorflow.Operations /// /// creates directory if not existing. /// - public static Operation write_file(Tensor filename, Tensor contents, string name = "WriteFile") + public static Tensor write_file(Tensor filename, Tensor contents, string name = "WriteFile") { var dict = new Dictionary(); dict["filename"] = filename; dict["contents"] = contents; var op = tf.OpDefLib._apply_op_helper("WriteFile", name: name, keywords: dict); - return op; + op.run(); + return op.output; } /// diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index 318b8b14..0c827b7f 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -2047,6 +2047,22 @@ new_height, new_width"); }); } + public static Tensor encode_jpeg(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "encode_jpeg"), scope => + { + return gen_ops.encode_jpeg(contents, name:name); + }); + } + + public static Tensor encode_png(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "encode_png"), scope => + { + return gen_ops.encode_png(contents, name: name); + }); + } + public static Tensor is_jpeg(Tensor contents, string name = null) { return tf_with(ops.name_scope(name, "is_jpeg"), scope => diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs index d671b609..6a430443 100644 --- a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs @@ -4,6 +4,7 @@ using System.Linq; using Tensorflow; using static Tensorflow.Binding; using System; +using System.IO; namespace TensorFlowNET.UnitTest { @@ -164,5 +165,32 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(result.size, 16ul); Assert.AreEqual(result[0, 0, 0, 0], 12f); } + + [TestMethod] + public void ImageSaveTest() + { + var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); + var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg"); + var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png"); + + File.Delete(jpegImgPath); + File.Delete(pngImgPath); + + var contents = tf.io.read_file(imgPath); + var bmp = tf.image.decode_image(contents); + Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0"); + + var jpeg = tf.image.encode_jpeg(bmp); + tf.io.write_file(jpegImgPath, jpeg); + Assert.IsTrue(File.Exists(jpegImgPath)); + + var png = tf.image.encode_png(bmp); + tf.io.write_file(pngImgPath, png); + Assert.IsTrue(File.Exists(pngImgPath)); + + // 如果要测试图片正确性,可以注释下面两行代码 + File.Delete(jpegImgPath); + File.Delete(pngImgPath); + } } }