Browse Source

feat: add adjust_contrast, adjust_hue, combined_non_max_suppression, crop_and_resize image oprs

tags/v0.110.4-Transformer-Model
“Wanglongzhi2001” 2 years ago
parent
commit
b0ce73caff
3 changed files with 479 additions and 15 deletions
  1. +119
    -12
      src/TensorFlowNET.Core/APIs/tf.image.cs
  2. +296
    -2
      src/TensorFlowNET.Core/Operations/gen_image_ops.cs
  3. +64
    -1
      test/TensorFlowNET.Graph.UnitTest/ImageTest.cs

+ 119
- 12
src/TensorFlowNET.Core/APIs/tf.image.cs View File

@@ -14,6 +14,10 @@
limitations under the License.
******************************************************************************/

using OneOf.Types;
using System;
using System.Buffers.Text;
using Tensorflow.Contexts;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -162,17 +166,108 @@ namespace Tensorflow
public Tensor sobel_edges(Tensor image)
{
    // Thin public wrapper: the actual work is done by image_ops_impl.sobel_edges.
    return image_ops_impl.sobel_edges(image);
}

/// <summary>
/// Decodes `contents` by forwarding to `gen_image_ops.decode_jpeg`.
/// </summary>
/// <param name="contents">Tensor handed to the decoder unchanged.</param>
/// <param name="channels">Forwarded unchanged; defaults to 0.</param>
/// <param name="ratio">Forwarded unchanged; defaults to 1.</param>
/// <param name="fancy_upscaling">Forwarded unchanged; defaults to true.</param>
/// <param name="try_recover_truncated">Forwarded unchanged; defaults to false.</param>
/// <param name="acceptable_fraction">Forwarded unchanged; defaults to 1.</param>
/// <param name="dct_method">Forwarded unchanged; defaults to the empty string.</param>
/// <param name="name">
/// NOTE(review): accepted but never forwarded to `gen_image_ops.decode_jpeg` — confirm whether it should be.
/// </param>
/// <returns>The tensor returned by `gen_image_ops.decode_jpeg`.</returns>
public Tensor decode_jpeg(Tensor contents,
int channels = 0,
int ratio = 1,
bool fancy_upscaling = true,
bool try_recover_truncated = false,
int acceptable_fraction = 1,
string dct_method = "",
string name = null)
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
acceptable_fraction: acceptable_fraction, dct_method: dct_method);
/// <summary>
/// Adjust contrast of RGB or grayscale images.
/// </summary>
/// <param name="images">Images to adjust. At least 3-D.</param>
/// <param name="contrast_factor">A float multiplier for adjusting contrast.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>The contrast-adjusted image or images.</returns>
public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null)
=> gen_image_ops.adjust_contrastv2(images, contrast_factor, name);

/// <summary>
/// Shifts the hue of RGB images by a constant offset.
/// </summary>
/// <param name="images">RGB image or images. The size of the last dimension must be 3.</param>
/// <param name="delta">float. How much to add to the hue channel.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>Adjusted image(s), same shape and DType as `images`.</returns>
/// <exception cref="ValueError">
/// Thrown when executing eagerly and `delta` is not in the interval `[-1, 1]`.
/// </exception>
public Tensor adjust_hue(Tensor images, float delta, string name = null)
{
    // The range check is applied only in eager mode; otherwise delta is passed through unchecked.
    if (tf.Context.executing_eagerly() && (delta < -1f || delta > 1f))
    {
        throw new ValueError("delta must be in the interval [-1, 1]");
    }

    return gen_image_ops.adjust_hue(images, delta, name: name);
}

/// <summary>
/// Scales the saturation of RGB images by a constant factor.
/// </summary>
/// <param name="image">RGB image or images. The size of the last dimension must be 3.</param>
/// <param name="saturation_factor">float. Factor to multiply the saturation by.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns>
public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null)
{
    // Delegates to the generated op wrapper, which converts the float factor to a tensor.
    return gen_image_ops.adjust_saturation(image, saturation_factor, name);
}

/// <summary>
/// Greedily selects a subset of bounding boxes in descending order of score.
/// </summary>
/// <param name="boxes">
/// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` is 1,
/// the same boxes are used for all classes; if `q` equals the number of classes,
/// class-specific boxes are used.
/// </param>
/// <param name="scores">
/// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]`
/// representing a single score corresponding to each box (each row of boxes).
/// </param>
/// <param name="max_output_size_per_class">
/// Maximum number of boxes to be selected by non-max suppression per class.
/// </param>
/// <param name="max_total_size">
/// Maximum number of boxes retained over all classes. Note that setting this
/// value to a large number may result in an OOM error depending on the system workload.
/// </param>
/// <param name="iou_threshold">
/// Threshold for deciding whether boxes overlap too much with respect to IOU.
/// </param>
/// <param name="score_threshold">
/// Threshold for deciding when to remove boxes based on score.
/// </param>
/// <param name="pad_per_class">
/// If false, the output nmsed boxes, scores and classes are padded/clipped to
/// `max_total_size`. If true, they are padded to be of length
/// `max_size_per_class`*`num_classes`, unless that exceeds `max_total_size`, in
/// which case they are clipped to `max_total_size`. Defaults to false.
/// </param>
/// <param name="clip_boxes">
/// If true, the coordinates of output nmsed boxes will be clipped to [0, 1].
/// If false, the box coordinates are output as they are. Defaults to true.
/// </param>
/// <returns>
/// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes.
/// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes.
/// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes.
/// 'valid_detections': A [batch_size] int32 tensor indicating the number of
/// valid detections per batch item. Only the top valid_detections[i] entries
/// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the
/// entries are zero paddings.
/// </returns>
public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(
    Tensor boxes,
    Tensor scores,
    int max_output_size_per_class,
    int max_total_size,
    float iou_threshold,
    float score_threshold,
    bool pad_per_class = false,
    bool clip_boxes = true)
{
    // The generated op wrapper takes tensors, so wrap each scalar argument first.
    // The two thresholds are converted explicitly as TF_FLOAT.
    Tensor iou_tensor = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold");
    Tensor score_tensor = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold");
    Tensor total_size_tensor = ops.convert_to_tensor(max_total_size);
    Tensor per_class_tensor = ops.convert_to_tensor(max_output_size_per_class);

    return gen_image_ops.combined_non_max_suppression(
        boxes,
        scores,
        per_class_tensor,
        total_size_tensor,
        iou_tensor,
        score_tensor,
        pad_per_class,
        clip_boxes);
}

/// <summary>
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change.
@@ -187,7 +282,19 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns>
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
    // Single expression body: delegate to the generated CropAndResize wrapper.
    // The stale image_ops_impl.crop_and_resize call that preceded this line was a
    // leftover second body and made the member syntactically invalid.
    gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);

/// <summary>
/// Decodes `contents` by forwarding to `gen_image_ops.decode_jpeg`.
/// </summary>
/// <param name="contents">Tensor handed to the decoder unchanged.</param>
/// <param name="channels">Forwarded unchanged; defaults to 0.</param>
/// <param name="ratio">Forwarded unchanged; defaults to 1.</param>
/// <param name="fancy_upscaling">Forwarded unchanged; defaults to true.</param>
/// <param name="try_recover_truncated">Forwarded unchanged; defaults to false.</param>
/// <param name="acceptable_fraction">Forwarded unchanged; defaults to 1.</param>
/// <param name="dct_method">Forwarded unchanged; defaults to the empty string.</param>
/// <param name="name">
/// NOTE(review): accepted but never forwarded to `gen_image_ops.decode_jpeg` — confirm whether it should be.
/// </param>
/// <returns>The tensor returned by `gen_image_ops.decode_jpeg`.</returns>
public Tensor decode_jpeg(Tensor contents,
int channels = 0,
int ratio = 1,
bool fancy_upscaling = true,
bool try_recover_truncated = false,
int acceptable_fraction = 1,
string dct_method = "",
string name = null)
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
acceptable_fraction: acceptable_fraction, dct_method: dct_method);

public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
bool uniform_noise = true, string name = null)


+ 296
- 2
src/TensorFlowNET.Core/Operations/gen_image_ops.cs View File

@@ -16,18 +16,312 @@

using System;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;
using Tensorflow.Exceptions;
using Tensorflow.Contexts;
using System.Xml.Linq;
using Google.Protobuf;

namespace Tensorflow
{
public class gen_image_ops
{
/// <summary>
/// Runs the "AdjustContrastv2" op on <paramref name="images"/> with a tensor-valued contrast factor.
/// </summary>
/// <remarks>
/// When executing eagerly, the fast path is tried first and then
/// <see cref="adjust_contrastv2_eager_fallback"/>. If neither returns, or when not
/// executing eagerly, the op is built through <c>tf.OpDefLib._apply_op_helper</c>.
/// </remarks>
/// <param name="images">Input passed to the op as "images".</param>
/// <param name="contrast_factor">Input passed to the op as "contrast_factor".</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first output of the op.</returns>
/// <exception cref="NotOkStatusException">Propagated unchanged from the fast path.</exception>
public static Tensor adjust_contrastv2(Tensor images, Tensor contrast_factor, string name = null)
{
    var _ctx = tf.Context;
    if (_ctx.executing_eagerly())
    {
        try
        {
            var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustContrastv2", name)
            {
                args = new object[] { images, contrast_factor },
                attrs = new Dictionary<string, object>() { }
            });
            return _fast_path_result[0];
        }
        catch (NotOkStatusException)
        {
            // `throw;` (not `throw ex;`) keeps the original stack trace intact.
            throw;
        }
        catch (Exception)
        {
            // Any other fast-path failure is deliberately ignored so the eager fallback can run.
        }
        try
        {
            return adjust_contrastv2_eager_fallback(images, contrast_factor, name: name, ctx: _ctx);
        }
        catch (Exception)
        {
            // Deliberately ignored: fall through to the op-def based construction below.
        }
    }
    Dictionary<string, object> keywords = new();
    keywords["images"] = images;
    keywords["contrast_factor"] = contrast_factor;
    var _op = tf.OpDefLib._apply_op_helper("AdjustContrastv2", name, keywords);
    var _result = _op.outputs;
    if (_execute.must_record_gradient())
    {
        object[] _attrs = new object[] { "T", _op._get_attr_type("T") };
        _execute.record_gradient("AdjustContrastv2", _op.inputs, _attrs, _result);
    }
    return _result[0];
}
/// <summary>
/// Convenience overload: wraps the float <paramref name="contrast_factor"/> in a tensor
/// and forwards to the tensor-based overload.
/// </summary>
public static Tensor adjust_contrastv2(Tensor image, float contrast_factor, string name = null)
    => adjust_contrastv2(image, tf.convert_to_tensor(contrast_factor), name: name);

/// <summary>
/// Eager fallback for "AdjustContrastv2": executes the op through <c>_execute.execute</c>
/// and records the gradient when required.
/// </summary>
/// <returns>The first output of the executed op.</returns>
public static Tensor adjust_contrastv2_eager_fallback(Tensor images, Tensor contrast_factor, string name, Context ctx)
{
    Tensor[] flat_inputs = new Tensor[] { images, contrast_factor };
    // The "T" attribute is taken from the dtype of `images`.
    object[] op_attrs = new object[] { "T", images.dtype };

    var outputs = _execute.execute("AdjustContrastv2", 1, inputs: flat_inputs, attrs: op_attrs, ctx: ctx, name: name);

    bool needs_gradient = _execute.must_record_gradient();
    if (needs_gradient)
    {
        _execute.record_gradient("AdjustContrastv2", flat_inputs, op_attrs, outputs);
    }

    return outputs[0];
}

/// <summary>
/// Runs the "AdjustHue" op on <paramref name="images"/> with a tensor-valued delta.
/// </summary>
/// <remarks>
/// When executing eagerly, the fast path is tried first and then
/// <see cref="adjust_hue_eager_fallback"/>. If neither returns, or when not
/// executing eagerly, the op is built through <c>tf.OpDefLib._apply_op_helper</c>.
/// </remarks>
/// <param name="images">Input passed to the op as "images".</param>
/// <param name="delta">Input passed to the op as "delta".</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first output of the op.</returns>
/// <exception cref="NotOkStatusException">Propagated unchanged from the fast path.</exception>
public static Tensor adjust_hue(Tensor images, Tensor delta, string name = null)
{
    var _ctx = tf.Context;
    if (_ctx.executing_eagerly())
    {
        try
        {
            var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustHue", name)
            {
                args = new object[] { images, delta },
                attrs = new Dictionary<string, object>() { }
            });
            return _fast_path_result[0];
        }
        catch (NotOkStatusException)
        {
            // `throw;` (not `throw ex;`) keeps the original stack trace intact.
            throw;
        }
        catch (Exception)
        {
            // Any other fast-path failure is deliberately ignored so the eager fallback can run.
        }
        try
        {
            return adjust_hue_eager_fallback(images, delta, name: name, ctx: _ctx);
        }
        catch (Exception)
        {
            // Deliberately ignored: fall through to the op-def based construction below.
        }
    }
    Dictionary<string, object> keywords = new();
    keywords["images"] = images;
    keywords["delta"] = delta;
    var _op = tf.OpDefLib._apply_op_helper("AdjustHue", name, keywords);
    var _result = _op.outputs;
    if (_execute.must_record_gradient())
    {
        object[] _attrs = new object[] { "T", _op._get_attr_type("T") };
        _execute.record_gradient("AdjustHue", _op.inputs, _attrs, _result);
    }
    return _result[0];
}

/// <summary>
/// Convenience overload: wraps the float <paramref name="delta"/> in a tensor and
/// forwards to the tensor-based overload.
/// </summary>
/// <param name="images">Images forwarded to the tensor-based overload.</param>
/// <param name="delta">Hue offset; converted with <c>ops.convert_to_tensor</c>.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The result of the tensor-based <c>adjust_hue</c> overload.</returns>
public static Tensor adjust_hue(Tensor images, float delta, string name = null)
    // The float must be converted explicitly: passing `delta` as-is resolves back to
    // this very overload (exact match) and recurses without end.
    => adjust_hue(images, ops.convert_to_tensor(delta), name: name);

/// <summary>
/// Eager fallback for "AdjustHue": executes the op through <c>_execute.execute</c>
/// and records the gradient when required.
/// </summary>
/// <returns>The first output of the executed op.</returns>
public static Tensor adjust_hue_eager_fallback(Tensor images, Tensor delta, string name, Context ctx)
{
    Tensor[] flat_inputs = new Tensor[] { images, delta };
    // The "T" attribute is taken from the dtype of `images`.
    object[] op_attrs = new object[] { "T", images.dtype };

    var outputs = _execute.execute("AdjustHue", 1, inputs: flat_inputs, attrs: op_attrs, ctx: ctx, name: name);

    bool needs_gradient = _execute.must_record_gradient();
    if (needs_gradient)
    {
        _execute.record_gradient("AdjustHue", flat_inputs, op_attrs, outputs);
    }

    return outputs[0];
}

/// <summary>
/// Runs the "AdjustSaturation" op on <paramref name="images"/> with a tensor-valued scale.
/// </summary>
/// <remarks>
/// When executing eagerly, the fast path is tried first and then
/// <see cref="adjust_saturation_eager_fallback"/>. If neither returns, or when not
/// executing eagerly, the op is built through <c>tf.OpDefLib._apply_op_helper</c>.
/// </remarks>
/// <param name="images">Input passed to the op as "images".</param>
/// <param name="scale">Input passed to the op as "scale".</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first output of the op.</returns>
/// <exception cref="NotOkStatusException">Propagated unchanged from the fast path.</exception>
public static Tensor adjust_saturation(Tensor images, Tensor scale, string name = null)
{
    var _ctx = tf.Context;
    if (_ctx.executing_eagerly())
    {
        try
        {
            var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "AdjustSaturation", name)
            {
                args = new object[] { images, scale },
                attrs = new Dictionary<string, object>() { }
            });
            return _fast_path_result[0];
        }
        catch (NotOkStatusException)
        {
            // `throw;` (not `throw ex;`) keeps the original stack trace intact.
            throw;
        }
        catch (Exception)
        {
            // Any other fast-path failure is deliberately ignored so the eager fallback can run.
        }
        try
        {
            // Must use the saturation fallback; the hue fallback would execute "AdjustHue" instead.
            return adjust_saturation_eager_fallback(images, scale, name: name, ctx: _ctx);
        }
        catch (Exception)
        {
            // Deliberately ignored: fall through to the op-def based construction below.
        }
    }
    Dictionary<string, object> keywords = new();
    keywords["images"] = images;
    keywords["scale"] = scale;
    var _op = tf.OpDefLib._apply_op_helper("AdjustSaturation", name, keywords);
    var _result = _op.outputs;
    if (_execute.must_record_gradient())
    {
        object[] _attrs = new object[] { "T", _op._get_attr_type("T") };
        _execute.record_gradient("AdjustSaturation", _op.inputs, _attrs, _result);
    }
    return _result[0];
}

/// <summary>
/// Convenience overload: wraps the float <paramref name="scale"/> in a tensor and
/// forwards to the tensor-based overload.
/// </summary>
public static Tensor adjust_saturation(Tensor images, float scale, string name = null)
{
    Tensor scale_tensor = ops.convert_to_tensor(scale);
    return adjust_saturation(images, scale_tensor, name: name);
}

/// <summary>
/// Eager fallback for "AdjustSaturation": executes the op through <c>_execute.execute</c>
/// and records the gradient when required.
/// </summary>
/// <returns>The first output of the executed op.</returns>
public static Tensor adjust_saturation_eager_fallback(Tensor images, Tensor scale, string name, Context ctx)
{
    Tensor[] flat_inputs = new Tensor[] { images, scale };
    // The "T" attribute is taken from the dtype of `images`.
    object[] op_attrs = new object[] { "T", images.dtype };

    var outputs = _execute.execute("AdjustSaturation", 1, inputs: flat_inputs, attrs: op_attrs, ctx: ctx, name: name);

    bool needs_gradient = _execute.must_record_gradient();
    if (needs_gradient)
    {
        _execute.record_gradient("AdjustSaturation", flat_inputs, op_attrs, outputs);
    }

    return outputs[0];
}

/// <summary>
/// Runs the "CombinedNonMaxSuppression" op and returns its four outputs.
/// </summary>
/// <remarks>
/// When executing eagerly, the fast path is tried first and then
/// <see cref="combined_non_max_suppression_eager_fallback"/>. If neither returns, or
/// when not executing eagerly, the op is built through <c>tf.OpDefLib._apply_op_helper</c>.
/// </remarks>
/// <param name="boxes">Input passed to the op as "boxes".</param>
/// <param name="scores">Input passed to the op as "scores".</param>
/// <param name="max_output_size_per_class">Input passed to the op as "max_output_size_per_class".</param>
/// <param name="max_total_size">Input passed to the op as "max_total_size".</param>
/// <param name="iou_threshold">Input passed to the op as "iou_threshold".</param>
/// <param name="score_threshold">Input passed to the op as "score_threshold".</param>
/// <param name="pad_per_class">Value for the "pad_per_class" attribute.</param>
/// <param name="clip_boxes">Value for the "clip_boxes" attribute.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>Outputs 0 through 3 of the op, in op order.</returns>
/// <exception cref="NotOkStatusException">Propagated unchanged from the fast path.</exception>
public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, Tensor max_total_size,
    Tensor iou_threshold, Tensor score_threshold, bool pad_per_class = false, bool clip_boxes = true, string name = null)
{
    var _ctx = tf.Context;
    if (_ctx.executing_eagerly())
    {
        try
        {
            var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CombinedNonMaxSuppression", name)
            {
                // The attribute name/value pairs follow the six inputs in the args array.
                args = new object[] {
                    boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold,
                    "pad_per_class", pad_per_class, "clip_boxes", clip_boxes },
                attrs = new Dictionary<string, object>() { }
            });
            return (_fast_path_result[0], _fast_path_result[1], _fast_path_result[2], _fast_path_result[3]);
        }
        catch (NotOkStatusException)
        {
            // `throw;` (not `throw ex;`) keeps the original stack trace intact.
            throw;
        }
        catch (Exception)
        {
            // Any other fast-path failure is deliberately ignored so the eager fallback can run.
        }
        try
        {
            return combined_non_max_suppression_eager_fallback(
                boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
                score_threshold, pad_per_class, clip_boxes, name, ctx: _ctx);
        }
        catch (Exception)
        {
            // Deliberately ignored: fall through to the op-def based construction below.
        }
    }
    Dictionary<string, object> keywords = new();
    keywords["boxes"] = boxes;
    keywords["scores"] = scores;
    keywords["max_output_size_per_class"] = max_output_size_per_class;
    keywords["max_total_size"] = max_total_size;
    keywords["iou_threshold"] = iou_threshold;
    keywords["score_threshold"] = score_threshold;
    keywords["pad_per_class"] = pad_per_class;
    keywords["clip_boxes"] = clip_boxes;

    var _op = tf.OpDefLib._apply_op_helper("CombinedNonMaxSuppression", name, keywords);
    var _result = _op.outputs;
    if (_execute.must_record_gradient())
    {
        // Both attributes are booleans, not dtypes, so read them with the generic
        // attribute getter rather than _get_attr_type.
        object[] _attrs = new object[] { "pad_per_class", _op.get_attr("pad_per_class"), "clip_boxes", _op.get_attr("clip_boxes") };
        _execute.record_gradient("CombinedNonMaxSuppression", _op.inputs, _attrs, _result);
    }
    return (_result[0], _result[1], _result[2], _result[3]);
}

/// <summary>
/// Eager fallback for "CombinedNonMaxSuppression": executes the op through
/// <c>_execute.execute</c> and records the gradient when required.
/// </summary>
/// <param name="boxes">First op input.</param>
/// <param name="scores">Second op input.</param>
/// <param name="max_output_size_per_class">Third op input.</param>
/// <param name="max_total_size">Fourth op input.</param>
/// <param name="iou_threshold">Fifth op input.</param>
/// <param name="score_threshold">Sixth op input.</param>
/// <param name="pad_per_class">Value for the "pad_per_class" attribute.</param>
/// <param name="clip_boxes">Value for the "clip_boxes" attribute.</param>
/// <param name="name">Operation name.</param>
/// <param name="ctx">Context the op is executed in.</param>
/// <returns>Outputs 0 through 3 of the executed op, in op order.</returns>
public static (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression_eager_fallback(Tensor boxes, Tensor scores, Tensor max_output_size_per_class, Tensor max_total_size,
    Tensor iou_threshold, Tensor score_threshold, bool pad_per_class, bool clip_boxes, string name, Context ctx)
{
    Tensor[] _inputs_flat = new Tensor[] { boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold };
    object[] _attrs = new object[] { "pad_per_class", pad_per_class, "clip_boxes", clip_boxes };
    // Four outputs are read below, so four must be requested (this previously asked for one).
    var _result = _execute.execute("CombinedNonMaxSuppression", 4, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
    if (_execute.must_record_gradient())
    {
        _execute.record_gradient("CombinedNonMaxSuppression", _inputs_flat, _attrs, _result);
    }
    return (_result[0], _result[1], _result[2], _result[3]);
}

/// <summary>
/// Runs the "CropAndResize" op and returns its first output.
/// </summary>
/// <remarks>
/// When executing eagerly, the fast path is tried first and then
/// <see cref="crop_and_resize_eager_fallback"/>. If neither returns, or when not
/// executing eagerly, the op is built through <c>tf.OpDefLib._apply_op_helper</c>.
/// </remarks>
/// <param name="image">Input passed to the op as "image".</param>
/// <param name="boxes">Input passed to the op as "boxes".</param>
/// <param name="box_ind">Input passed to the op as "box_ind".</param>
/// <param name="crop_size">Input passed to the op as "crop_size".</param>
/// <param name="method">Value for the "method" attribute; defaults to "bilinear".</param>
/// <param name="extrapolation_value">Value for the "extrapolation_value" attribute; defaults to 0.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The first output of the op.</returns>
/// <exception cref="NotOkStatusException">Propagated unchanged from the fast path.</exception>
public static Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null)
{
    var _ctx = tf.Context;
    if (_ctx.executing_eagerly())
    {
        try
        {
            var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "CropAndResize", name)
            {
                // The attribute name/value pairs follow the four inputs in the args array.
                args = new object[] {
                    image, boxes, box_ind, crop_size, "method", method, "extrapolation_value", extrapolation_value },
                attrs = new Dictionary<string, object>() { }
            });
            return _fast_path_result[0];
        }
        catch (NotOkStatusException)
        {
            // `throw;` (not `throw ex;`) keeps the original stack trace intact.
            throw;
        }
        catch (Exception)
        {
            // Any other fast-path failure is deliberately ignored so the eager fallback can run.
        }
        try
        {
            return crop_and_resize_eager_fallback(
                image, boxes, box_ind, crop_size, method: method, extrapolation_value: extrapolation_value, name: name, ctx: _ctx);
        }
        catch (Exception)
        {
            // Deliberately ignored: fall through to the op-def based construction below.
        }
    }
    Dictionary<string, object> keywords = new();
    keywords["image"] = image;
    keywords["boxes"] = boxes;
    keywords["box_ind"] = box_ind;
    keywords["crop_size"] = crop_size;
    keywords["method"] = method;
    keywords["extrapolation_value"] = extrapolation_value;
    var _op = tf.OpDefLib._apply_op_helper("CropAndResize", name, keywords);
    var _result = _op.outputs;
    if (_execute.must_record_gradient())
    {
        // Only "T" is a dtype attribute; "method" is a string, so it is read with the
        // generic getter, the same way "extrapolation_value" is.
        object[] _attrs = new object[] { "T", _op._get_attr_type("T"), "method", _op.get_attr("method"),
            "extrapolation_value", _op.get_attr("extrapolation_value") };
        _execute.record_gradient("CropAndResize", _op.inputs, _attrs, _result);
    }
    return _result[0];
}

/// <summary>
/// Eager fallback for "CropAndResize": executes the op through <c>_execute.execute</c>
/// and records the gradient when required.
/// </summary>
/// <param name="image">First op input; its dtype supplies the "T" attribute.</param>
/// <param name="boxes">Second op input.</param>
/// <param name="box_ind">Third op input.</param>
/// <param name="crop_size">Fourth op input.</param>
/// <param name="method">Value for the "method" attribute; null is replaced by "bilinear".</param>
/// <param name="extrapolation_value">Value for the "extrapolation_value" attribute.</param>
/// <param name="name">Operation name.</param>
/// <param name="ctx">Context the op is executed in.</param>
/// <returns>The first output of the executed op.</returns>
public static Tensor crop_and_resize_eager_fallback(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method, float extrapolation_value, string name, Context ctx)
{
    if (method is null)
        method = "bilinear";

    // "method" and "extrapolation_value" are op attributes (the sibling crop_and_resize
    // passes them as name/value pairs), so they belong in _attrs, not among the input tensors.
    Tensor[] _inputs_flat = new Tensor[] { image, boxes, box_ind, crop_size };
    object[] _attrs = new object[] { "T", image.dtype, "method", method, "extrapolation_value", extrapolation_value };
    var _result = _execute.execute("CropAndResize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name);
    if (_execute.must_record_gradient())
    {
        _execute.record_gradient("CropAndResize", _inputs_flat, _attrs, _result);
    }
    return _result[0];
}


public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null)
{
if (dtype == image.dtype)


+ 64
- 1
test/TensorFlowNET.Graph.UnitTest/ImageTest.cs View File

@@ -3,6 +3,7 @@ using Tensorflow.NumPy;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using System;

namespace TensorFlowNET.UnitTest
{
@@ -22,13 +23,75 @@ namespace TensorFlowNET.UnitTest
contents = tf.io.read_file(imgPath);
}

/// <summary>
/// Checks that <c>tf.image.adjust_contrast</c> with factor 2 turns the values 0..8
/// (reshaped to 3x3x1) into -4, -2, ..., 12.
/// </summary>
[TestMethod]
public void adjust_contrast()
{
    var input = np.array(0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f);
    var image = tf.reshape(input, new int[] { 3, 3, 1 });

    var adjusted = tf.image.adjust_contrast(image, 2.0f);

    var expected = np.array(-4f, -2f, 0f, 2f, 4f, 6f, 8f, 10f, 12f).reshape((3, 3, 1));
    // MSTest's signature is AreEqual(expected, actual); this order keeps failure messages accurate.
    Assert.AreEqual(expected, adjusted.numpy());
}

// Exercises tf.image.adjust_hue on an 18-element tensor reshaped to 3x2x3 with delta 0.2.
// The test is currently disabled through [Ignore].
// NOTE(review): the input is built from int values — confirm the AdjustHue op accepts integer images.
// NOTE(review): unlike the sibling tests, the assertion compares the Tensor objects directly
// instead of their numpy() values — confirm this is the intended comparison.
[Ignore]
[TestMethod]
public void adjust_hue()
{
var image = tf.constant(new int[] {1,2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18});
image = tf.reshape(image, new int[] { 3, 2, 3 });
var adjusted_image = tf.image.adjust_hue(image, 0.2f);
var res = tf.constant(new int[] {2,1,3, 4, 5, 6,8,7,9,11,10,12,14,13,15,17,16,18});
res = tf.reshape(res,(3,2,3));
Assert.AreEqual(adjusted_image, res);
}

/// <summary>
/// Runs <c>tf.image.combined_non_max_suppression</c> on one batch of four boxes with three
/// class scores each and compares all four outputs with hand-written expectations
/// (three valid detections, the remainder zero padding up to ten entries).
/// </summary>
[TestMethod]
public void combined_non_max_suppression()
{
    var boxesX = tf.constant(new float[,] { { 200, 100, 150, 100 }, { 220, 120, 150, 100 }, { 190, 110, 150, 100 }, { 210, 112, 150, 100 } });
    var boxes1 = tf.reshape(boxesX, (1, 4, 1, 4));
    var scoresX = tf.constant(new float[,] { { 0.2f, 0.7f, 0.1f }, { 0.1f, 0.8f, 0.1f }, { 0.3f, 0.6f, 0.1f }, { 0.05f, 0.9f, 0.05f } });
    var scores1 = tf.reshape(scoresX, (1, 4, 3));
    var (boxes, scores, classes, valid_detections) = tf.image.combined_non_max_suppression(boxes1, scores1, 10, 10, 0.5f, 0.2f, clip_boxes: false);

    // MSTest's signature is AreEqual(expected, actual); the expected value goes first in every assertion.
    var boxes_gt = tf.constant(new float[,] { { 210f, 112f, 150f, 100f }, { 200f, 100f, 150f, 100f }, { 190f, 110f, 150f, 100f },
        { 0f, 0f, 0f, 0f }, { 0f, 0f, 0f, 0f }, { 0f, 0f, 0f, 0f }, { 0f, 0f, 0f, 0f }, { 0f, 0f, 0f, 0f }, { 0f, 0f, 0f, 0f }, { 0f, 0f, 0f, 0f } });
    boxes_gt = tf.reshape(boxes_gt, (1, 10, 4));
    Assert.AreEqual(boxes_gt.numpy(), boxes.numpy());
    var scores_gt = tf.constant(new float[,] { { 0.9f, 0.7f, 0.3f, 0f, 0f, 0f, 0f, 0f, 0f, 0f } });
    scores_gt = tf.reshape(scores_gt, (1, 10));
    Assert.AreEqual(scores_gt.numpy(), scores.numpy());
    var classes_gt = tf.constant(new float[,] { { 1f, 1f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f } });
    classes_gt = tf.reshape(classes_gt, (1, 10));
    Assert.AreEqual(classes_gt.numpy(), classes.numpy());
    var valid_detections_gt = tf.constant(new int[,] { { 3 } });
    valid_detections_gt = tf.reshape(valid_detections_gt, (1));
    Assert.AreEqual(valid_detections_gt.numpy(), valid_detections.numpy());
}

/// <summary>
/// Crops five random boxes out of a single random 256x256x3 image, resizes them to
/// 24x24, and verifies only the shape of the result.
/// </summary>
[TestMethod]
public void crop_and_resize()
{
    var batchSize = 1;
    var numBoxes = 5;
    var imageHeight = 256;
    var imageWidth = 256;
    var channels = 3;

    var crop_size = tf.constant(new int[] { 24, 24 });
    var image = tf.random.uniform((batchSize, imageHeight, imageWidth, channels));
    var boxes = tf.random.uniform((numBoxes, 4));
    var box_ind = tf.random.uniform((numBoxes), minval: 0, maxval: batchSize, dtype: TF_DataType.TF_INT32);

    var cropped = tf.image.crop_and_resize(image, boxes, box_ind, crop_size);

    Assert.AreEqual((5,24,24,3), cropped.shape);
}

/// <summary>
/// Decodes the test image contents and checks the name of the resulting tensor.
/// </summary>
[TestMethod]
public void decode_image()
{
    var img = tf.image.decode_image(contents);
    // MSTest's signature is AreEqual(expected, actual); this order keeps failure messages accurate.
    Assert.AreEqual("decode_image/DecodeImage:0", img.name);
}

[TestMethod]
public void resize_image()
{


Loading…
Cancel
Save