Browse Source

add tf.image.decode_image #350

tags/v0.12
Oceania2018 6 years ago
parent
commit
9d90d74ddd
9 changed files with 332 additions and 1 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/APIs/tf.image.cs
  2. +32
    -0
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  3. +63
    -0
      src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs
  4. +42
    -0
      src/TensorFlowNET.Core/Operations/gen_string_ops.cs
  5. +105
    -0
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  6. +6
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  7. +38
    -0
      src/TensorFlowNET.Core/Operations/string_ops.cs
  8. +2
    -1
      test/TensorFlowNET.UnitTest/GraphTest.cs
  9. +30
    -0
      test/TensorFlowNET.UnitTest/ImageTest.cs

+ 14
- 0
src/TensorFlowNET.Core/APIs/tf.image.cs View File

@@ -42,6 +42,20 @@ namespace Tensorflow


public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null)
=> gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name);

/// <summary>
/// Public entry point for image decoding; forwards every argument unchanged to
/// <c>image_ops_impl.decode_image</c>.
/// </summary>
/// <param name="contents">Tensor handed to the decoder as the encoded image.</param>
/// <param name="channels">Channel count forwarded to the decoder (defaults to 0).</param>
/// <param name="dtype">Requested data type of the result (defaults to TF_UINT8).</param>
/// <param name="name">Optional name forwarded to the implementation.</param>
/// <param name="expand_animations">Forwarded as-is to the implementation.</param>
/// <returns>Whatever <c>image_ops_impl.decode_image</c> returns.</returns>
public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8,
    string name = null, bool expand_animations = true)
{
    return image_ops_impl.decode_image(
        contents,
        channels: channels,
        dtype: dtype,
        name: name,
        expand_animations: expand_animations);
}

/// <summary>
/// Checks whether 'contents' encodes a JPEG image by delegating to
/// <c>image_ops_impl.is_jpeg</c>.
/// </summary>
/// <param name="contents">Tensor holding the encoded bytes to inspect.</param>
/// <param name="name">Optional name forwarded to the implementation.</param>
/// <returns>The tensor produced by <c>image_ops_impl.is_jpeg</c>.</returns>
// NOTE(review): this member is static while the neighbouring decode_image and
// convert_image_dtype wrappers are instance methods - confirm that is intended.
public static Tensor is_jpeg(Tensor contents, string name = null)
{
    return image_ops_impl.is_jpeg(contents, name: name);
}
} }
} }
} }

+ 32
- 0
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -0,0 +1,32 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using Tensorflow.IO;

namespace Tensorflow
{
    public partial class tensorflow
    {
        /// <summary>
        /// Holder for string operations, exposed as a member of <c>tensorflow</c>.
        /// </summary>
        public strings_internal strings = new strings_internal();

        /// <summary>
        /// String operations; each member forwards to <c>string_ops</c>.
        /// </summary>
        public class strings_internal
        {
            /// <summary>
            /// Forwards to <c>string_ops.substr</c> with the same arguments.
            /// </summary>
            /// <param name="input">Tensor passed through as the op input.</param>
            /// <param name="pos">Start position passed through to the op.</param>
            /// <param name="len">Length passed through to the op.</param>
            /// <param name="name">Optional op name.</param>
            /// <param name="uint">
            /// Passed through unchanged (defaults to "BYTE").
            /// NOTE(review): the parameter name looks like a misspelling of "unit"; it is kept
            /// because renaming would break callers that use named arguments.
            /// </param>
            /// <returns>The tensor returned by <c>string_ops.substr</c>.</returns>
            public Tensor substr(Tensor input, int pos, int len,
                string name = null, string @uint = "BYTE")
            {
                return string_ops.substr(input, pos, len, name: name, @uint: @uint);
            }
        }
    }
}

+ 63
- 0
src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs View File

@@ -88,6 +88,69 @@ namespace Tensorflow
} }
} }


/// <summary>
/// Adds a "DecodeGif" node to the graph with <paramref name="contents"/> as its only argument.
/// </summary>
/// <param name="contents">Tensor passed to the op as <c>contents</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output of the created op.</returns>
/// <exception cref="NotImplementedException">Thrown when executing eagerly.</exception>
public static Tensor decode_gif(Tensor contents,
    string name = null)
{
    // Only graph construction is implemented; the eager path bails out immediately.
    if (tf.context.executing_eagerly())
        throw new NotImplementedException("decode_gif");

    var decodeOp = _op_def_lib._apply_op_helper("DecodeGif", name: name, args: new { contents });
    return decodeOp.output;
}

/// <summary>
/// Adds a "DecodePng" node to the graph, passing <paramref name="contents"/>,
/// <paramref name="channels"/> and <paramref name="dtype"/> to the op.
/// </summary>
/// <param name="contents">Tensor passed to the op as <c>contents</c>.</param>
/// <param name="channels">Value passed to the op as <c>channels</c> (defaults to 0).</param>
/// <param name="dtype">Value passed to the op as <c>dtype</c> (defaults to TF_UINT8).</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output of the created op.</returns>
/// <exception cref="NotImplementedException">Thrown when executing eagerly.</exception>
public static Tensor decode_png(Tensor contents,
    int channels = 0,
    TF_DataType dtype = TF_DataType.TF_UINT8,
    string name = null)
{
    // Only graph construction is implemented; the eager path bails out immediately.
    if (tf.context.executing_eagerly())
        throw new NotImplementedException("decode_png");

    var decodeOp = _op_def_lib._apply_op_helper("DecodePng", name: name, args: new { contents, channels, dtype });
    return decodeOp.output;
}

/// <summary>
/// Adds a "DecodeBmp" node to the graph, passing <paramref name="contents"/> and
/// <paramref name="channels"/> to the op.
/// </summary>
/// <param name="contents">Tensor passed to the op as <c>contents</c>.</param>
/// <param name="channels">Value passed to the op as <c>channels</c> (defaults to 0).</param>
/// <param name="name">Optional op name.</param>
/// <returns>The output of the created op.</returns>
/// <exception cref="NotImplementedException">Thrown when executing eagerly.</exception>
public static Tensor decode_bmp(Tensor contents,
    int channels = 0,
    string name = null)
{
    // Only graph construction is implemented; the eager path bails out immediately.
    if (tf.context.executing_eagerly())
        throw new NotImplementedException("decode_bmp");

    var decodeOp = _op_def_lib._apply_op_helper("DecodeBmp", name: name, args: new { contents, channels });
    return decodeOp.output;
}

public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) public static Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null)
{ {
if (tf.context.executing_eagerly()) if (tf.context.executing_eagerly())


+ 42
- 0
src/TensorFlowNET.Core/Operations/gen_string_ops.cs View File

@@ -0,0 +1,42 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Thin wrappers that create string-related ops through an <c>OpDefLibrary</c>.
    /// </summary>
    public class gen_string_ops
    {
        // Shared op-definition library, created once by the static constructor below.
        static readonly OpDefLibrary _op_def_lib;

        static gen_string_ops()
        {
            _op_def_lib = new OpDefLibrary();
        }

        /// <summary>
        /// Creates a "Substr" op from the given arguments.
        /// </summary>
        /// <param name="input">Tensor passed to the op as <c>input</c>.</param>
        /// <param name="pos">Value passed to the op as <c>pos</c>.</param>
        /// <param name="len">Value passed to the op as <c>len</c>.</param>
        /// <param name="name">Optional op name.</param>
        /// <param name="uint">Value passed to the op under the key <c>unit</c> (defaults to "BYTE").</param>
        /// <returns>The output of the created op.</returns>
        public static Tensor substr(Tensor input, int pos, int len,
            string name = null, string @uint = "BYTE")
        {
            // The anonymous-type member is named "unit" even though the parameter is "@uint".
            var opArgs = new
            {
                input,
                pos,
                len,
                unit = @uint
            };

            return _op_def_lib._apply_op_helper("Substr", name: name, args: opArgs).output;
        }
    }
}

+ 105
- 0
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -17,11 +17,116 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public class image_ops_impl public class image_ops_impl
{ {
/// <summary>
/// Builds a chain of conditional ops that selects a decoder for 'contents':
/// JPEG is tested first, then PNG, then GIF, and BMP is the final fallback.
/// </summary>
/// <param name="contents">Tensor holding the encoded image.</param>
/// <param name="channels">Channel count; each branch asserts the values it accepts.</param>
/// <param name="dtype">Data type every branch passes to convert_image_dtype.</param>
/// <param name="name">Optional name used for the enclosing name scope ("decode_image" by default).</param>
/// <param name="expand_animations">When false the GIF branch throws NotImplementedException.</param>
/// <returns>The result of the outermost "cond_jpeg" conditional.</returns>
public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8,
string name = null, bool expand_animations = true)
{
// Assigned inside the name scope at the bottom, then read by the check_gif closure.
Tensor substr = null;


// JPEG branch: rejects channels == 4, then decodes and converts to dtype.
Func<ITensorOrOperation> _jpeg = () =>
{
int jpeg_channels = channels;
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
{
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
});
};

// GIF branch: rejects channels == 1 and channels == 4, then decodes and converts to dtype.
Func<ITensorOrOperation> _gif = () =>
{
int gif_channels = channels;
// Both not_equal ops are requested under the same name "check_gif_channels".
var good_channels = math_ops.logical_and(
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));

string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
{
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
// Taking only the first frame is not implemented yet; the intended call is left commented out.
if (!expand_animations)
// result = array_ops.gather(result, 0);
throw new NotImplementedException("");
return result;
});
};

// BMP branch (final fallback): asserts the first two bytes are "BM" and channels != 1.
Func<ITensorOrOperation> _bmp = () =>
{
int bmp_channels = channels;
var signature = string_ops.substr(contents, 0, 2);
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
{
// Note: channels is not forwarded to decode_bmp here.
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
});
};

// PNG branch: decodes with the caller's channels and dtype, then converts to dtype.
// NOTE(review): dtype is handed straight to decode_png - confirm the op accepts every
// dtype callers may request.
Func<ITensorOrOperation> _png = () =>
{
return convert_image_dtype(gen_image_ops.decode_png(
contents,
channels,
dtype: dtype),
dtype);
};

// Not PNG: compare the captured 3-byte prefix with "GIF" and pick the GIF or BMP branch.
Func<ITensorOrOperation> check_gif = () =>
{
var is_gif = math_ops.equal(substr, "\x47\x49\x46", name: "is_gif");
return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif");
};

// Not JPEG: pick the PNG branch or fall through to the GIF/BMP check.
Func<ITensorOrOperation> check_png = () =>
{
return control_flow_ops.cond(_is_png(contents), _png, check_gif, name: "cond_png");
};

return tf_with(ops.name_scope(name, "decode_image"), scope =>
{
// First three bytes of contents; used by check_gif through the captured variable.
substr = string_ops.substr(contents, 0, 3);
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
});
}

/// <summary>
/// Builds an equality op comparing the first three bytes of 'contents' with the
/// literal "\xff\xd8\xff", inside an "is_jpeg" name scope.
/// </summary>
/// <param name="contents">Tensor holding the encoded bytes to inspect.</param>
/// <param name="name">Optional name, used for both the name scope and the equality op.</param>
/// <returns>The tensor produced by the equality op.</returns>
public static Tensor is_jpeg(Tensor contents, string name = null)
{
    var nameScope = ops.name_scope(name, "is_jpeg");
    return tf_with(nameScope, ns =>
    {
        var header = string_ops.substr(contents, 0, 3);
        var matches = math_ops.equal(header, "\xff\xd8\xff", name: name);
        return matches;
    });
}

/// <summary>
/// Builds an equality op comparing the first three bytes of 'contents' with the start of
/// the PNG signature (0x89, 'P', 'N'), inside an "is_png" name scope.
/// </summary>
/// <param name="contents">Tensor holding the encoded bytes to inspect.</param>
/// <param name="name">Optional name, used for both the name scope and the equality op.</param>
/// <returns>The tensor produced by the equality op.</returns>
public static Tensor _is_png(Tensor contents, string name = null)
{
    return tf_with(ops.name_scope(name, "is_png"), scope =>
    {
        var substr = string_ops.substr(contents, 0, 3);
        // The PNG signature begins with byte 0x89 followed by "PN". A verbatim literal
        // @"\211PN" does not process escapes: it is the six characters \ 2 1 1 P N and can
        // never equal a three-byte substring. C# has no octal escapes, so use \u0089.
        // NOTE(review): confirm the string-to-tensor conversion emits U+0089 as the single
        // byte 0x89 (the same question applies to the non-ASCII literal in is_jpeg).
        return math_ops.equal(substr, "\u0089PN", name: name);
    });
}

/// <summary>
/// Returns an identity of 'image' when it already has the requested dtype.
/// Any real conversion is not implemented yet.
/// </summary>
/// <param name="image">Tensor whose dtype is compared with <paramref name="dtype"/>.</param>
/// <param name="dtype">Requested data type.</param>
/// <param name="saturate">Currently unused by this implementation.</param>
/// <param name="name">Optional name for the identity op.</param>
/// <returns>The identity op output when the dtypes match.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when <paramref name="dtype"/> differs from the dtype of <paramref name="image"/>.
/// </exception>
public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false,
    string name = null)
{
    if (dtype == image.dtype)
        return array_ops.identity(image, name: name);

    // Say which conversion was requested instead of throwing with an empty message.
    throw new NotImplementedException(
        $"convert_image_dtype: conversion from {image.dtype} to {dtype} is not implemented");
}
} }
} }

+ 6
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -168,6 +168,9 @@ namespace Tensorflow
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(x, y, name: name); => gen_math_ops.mul(x, y, name: name);


/// <summary>
/// Forwards to <c>gen_math_ops.not_equal</c> with the same operands and name.
/// </summary>
/// <param name="x">Left operand.</param>
/// <param name="y">Right operand.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by <c>gen_math_ops.not_equal</c>.</returns>
public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.not_equal(x, y, name: name);
}

public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null) public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.mul_no_nan(x, y, name: name); => gen_math_ops.mul_no_nan(x, y, name: name);


@@ -264,6 +267,9 @@ namespace Tensorflow
return gen_math_ops.log(x, name); return gen_math_ops.log(x, name);
} }


/// <summary>
/// Forwards to <c>gen_math_ops.logical_and</c> with the same operands and name.
/// </summary>
/// <param name="x">Left operand.</param>
/// <param name="y">Right operand.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by <c>gen_math_ops.logical_and</c>.</returns>
public static Tensor logical_and(Tensor x, Tensor y, string name = null)
{
    return gen_math_ops.logical_and(x, y, name: name);
}

public static Tensor lgamma(Tensor x, string name = null) public static Tensor lgamma(Tensor x, string name = null)
=> gen_math_ops.lgamma(x, name: name); => gen_math_ops.lgamma(x, name: name);




+ 38
- 0
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -0,0 +1,38 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// String operations; members forward to <c>gen_string_ops</c>.
    /// </summary>
    public class string_ops
    {
        /// <summary>
        /// Return substrings from `Tensor` of strings.
        /// </summary>
        /// <param name="input">Tensor passed through as the op input.</param>
        /// <param name="pos">Start position passed through to the op.</param>
        /// <param name="len">Length passed through to the op.</param>
        /// <param name="name">Optional op name.</param>
        /// <param name="uint">Passed through unchanged to <c>gen_string_ops.substr</c> (defaults to "BYTE").</param>
        /// <returns>The tensor returned by <c>gen_string_ops.substr</c>.</returns>
        public static Tensor substr(Tensor input, int pos, int len,
            string name = null, string @uint = "BYTE")
        {
            return gen_string_ops.substr(input, pos, len, name: name, @uint: @uint);
        }
    }
}

+ 2
- 1
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -416,12 +416,13 @@ namespace TensorFlowNET.UnitTest


} }


[TestMethod]
public void ImportGraphMeta() public void ImportGraphMeta()
{ {
var dir = "my-save-dir/"; var dir = "my-save-dir/";
using (var sess = tf.Session()) using (var sess = tf.Session())
{ {
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
var new_saver = tf.train.import_meta_graph(@"D:\tmp\resnet_v2_101_2017_04_14\eval.graph");
new_saver.restore(sess, dir + "my-model-10000"); new_saver.restore(sess, dir + "my-model-10000");
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
var batch_size = tf.size(labels); var batch_size = tf.size(labels);


+ 30
- 0
test/TensorFlowNET.UnitTest/ImageTest.cs View File

@@ -0,0 +1,30 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Tests for the tf.image API, run against a sample JPEG file.
    /// </summary>
    [TestClass]
    public class ImageTest
    {
        // Path to the sample image; resolved to an absolute path in the constructor.
        string imgPath = "../../../../../data/shasta-daisy.jpg";
        // Tensor produced by tf.read_file for the sample image.
        Tensor contents;

        public ImageTest()
        {
            imgPath = Path.GetFullPath(imgPath);
            contents = tf.read_file(imgPath);
        }

        /// <summary>
        /// Checks the name of the tensor that decode_image returns.
        /// </summary>
        [TestMethod]
        public void decode_image()
        {
            var img = tf.image.decode_image(contents);
            // MSTest expects (expected, actual); the swapped order produced misleading
            // failure messages.
            Assert.AreEqual("decode_image/cond_jpeg/Merge:0", img.name);
        }
    }
}

Loading…
Cancel
Save