Browse Source

implement more of tf.image

^
tags/v0.20
carb0n Haiping 5 years ago
parent
commit
c6601f9099
6 changed files with 2320 additions and 18 deletions
  1. +157
    -4
      src/TensorFlowNET.Core/APIs/tf.image.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +17
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +2034
    -14
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  5. +102
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  6. +5
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs

+ 157
- 4
src/TensorFlowNET.Core/APIs/tf.image.cs View File

@@ -1,4 +1,4 @@
/*****************************************************************************
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.


Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +25,136 @@ namespace Tensorflow


public class image_internal public class image_internal
{ {
public Tensor random_flip_up_down(Tensor image, int seed = 0)
=> image_ops_impl.random_flip_up_down(image, seed);

public Tensor random_flip_left_right(Tensor image, int seed = 0)
=> image_ops_impl.random_flip_left_right(image, seed);

public Tensor flip_left_right(Tensor image)
=> image_ops_impl.flip_left_right(image);
public Tensor flip_up_down(Tensor image)
=> image_ops_impl.flip_up_down(image);

public Tensor rot90(Tensor image, int k = 1, string name = null)
=> image_ops_impl.rot90(image, k, name);

public Tensor transpose(Tensor image, string name = null)
=> image_ops_impl.transpose(image, name);

public Tensor central_crop(Tensor image, float central_fraction)
=> image_ops_impl.central_crop(image, central_fraction);

public Tensor pad_to_bounding_box(Tensor image, int offset_height, int offset_width, int target_height, int target_width)
=> image_ops_impl.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width);

public Tensor crop_to_bounding_box(Tensor image, int offset_height, int offset_width, int target_height, int target_width)
=> image_ops_impl.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width);

public Tensor resize_image_with_crop_or_pad(Tensor image, object target_height, object target_width)
=> image_ops_impl.resize_image_with_crop_or_pad(image, target_height, target_width);

public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false,
string name = null)
=> image_ops_impl.resize_images(images, size, method, preserve_aspect_ratio, antialias, name);

public Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, string method, bool antialias)
=> image_ops_impl.resize_images_with_pad(image, target_height, target_width, method, antialias);

public Tensor per_image_standardization(Tensor image)
=> image_ops_impl.per_image_standardization(image);

public Tensor random_brightness(Tensor image, float max_delta, int seed = 0)
=> image_ops_impl.random_brightness(image, max_delta, seed);

public Tensor random_contrast(Tensor image, float lower, float upper, int seed = 0)
=> image_ops_impl.random_contrast(image, lower, upper, seed);

public Tensor adjust_brightness(Tensor image, Tensor delta)
=> image_ops_impl.adjust_brightness(image, delta);

public Tensor adjust_contrast(Tensor images, Tensor contrast_factor)
=> image_ops_impl.adjust_contrast(images, contrast_factor);

public Tensor adjust_gamma(Tensor image, int gamma = 1, int gain = 1)
=> image_ops_impl.adjust_gamma(image, gamma, gain);

public Tensor rgb_to_grayscale(Tensor images, string name = null)
=> image_ops_impl.rgb_to_grayscale(images, name);

public Tensor grayscale_to_rgb(Tensor images, string name = null)
=> image_ops_impl.grayscale_to_rgb(images, name);

public Tensor random_hue(Tensor image, float max_delta, int seed = 0)
=> image_ops_impl.random_hue(image, max_delta, seed);
public Tensor adjust_hue(Tensor image, Tensor delta, string name = null)
=> image_ops_impl.adjust_hue(image, delta, name);

public Tensor random_jpeg_quality(Tensor image, float min_jpeg_quality, float max_jpeg_quality, int seed = 0)
=> image_ops_impl.random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed);

public Tensor adjust_jpeg_quality(Tensor image, Tensor jpeg_quality, string name = null)
=> image_ops_impl.adjust_jpeg_quality(image, jpeg_quality, name);

public Tensor random_saturation(Tensor image, float lower, float upper, int seed = 0)
=> image_ops_impl.random_saturation(image, lower, upper, seed);

public Tensor adjust_saturation(Tensor image, Tensor saturation_factor, string name = null)
=> image_ops_impl.adjust_saturation(image, saturation_factor, name);

public Tensor total_variation(Tensor images, string name = null)
=> image_ops_impl.total_variation(images, name);

public (Tensor, Tensor, Tensor) sample_distorted_bounding_box(Tensor image_size, Tensor bounding_boxes,
int seed = 0,
Tensor min_object_covered = null,
float[] aspect_ratio_range = null,
float[] area_range = null,
int max_attempts = 100,
bool use_image_if_no_bounding_boxes = false,
string name = null)
=> image_ops_impl.sample_distorted_bounding_box_v2(image_size, bounding_boxes, seed, min_object_covered, aspect_ratio_range,
area_range, max_attempts, use_image_if_no_bounding_boxes, name);

public Tensor non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size, float iou_threshold = 0.5f,
float score_threshold = -1f / 0f, /*float soft_nms_sigma = 0.0f,*/ string name = null)
=> image_ops_impl.non_max_suppression(boxes, scores, max_output_size, iou_threshold, score_threshold, name);

public Tensor non_max_suppression_with_overlaps(Tensor overlaps, Tensor scores, Tensor max_output_size,
float overlap_threshold = 0.5f, float score_threshold = -1 / 0f, string name = null)
=> image_ops_impl.non_max_suppression_with_overlaps(overlaps, scores, max_output_size, overlap_threshold, score_threshold, name);

public Tensor rgb_to_yiq(Tensor images)
=> image_ops_impl.rgb_to_yiq(images);

public Tensor yiq_to_rgb(Tensor images)
=> image_ops_impl.yiq_to_rgb(images);

public Tensor rgb_to_yuv(Tensor images)
=> image_ops_impl.rgb_to_yuv(images);

public Tensor yuv_to_rgb(Tensor images)
=> image_ops_impl.yuv_to_rgb(images);

public Tensor psnr(Tensor a, Tensor b, Tensor max_val, string name = null)
=> image_ops_impl.psnr(a, b, max_val, name);

public Tensor ssim(Tensor img1, Tensor img2, float max_val = 1f, float filter_size = 11f, float filter_sigma = 1.5f,
float k1 = 0.01f, float k2 = 0.03f)
=> image_ops_impl.ssim(img1, img2, max_val, filter_size, filter_sigma, k1, k2);

public Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] power_factors = null, float filter_size = 11f,
float filter_sigma = 1.5f, float k1 = 0.01f, float k2 = 0.03f)
=> image_ops_impl.ssim_multiscale(img1, img2, max_val, power_factors, filter_size, filter_sigma, k1, k2);

public (Tensor, Tensor) image_gradients(Tensor image)
=> image_ops_impl.image_gradients(image);

public Tensor sobel_edges(Tensor image)
=> image_ops_impl.sobel_edges(image);

public Tensor decode_jpeg(Tensor contents, public Tensor decode_jpeg(Tensor contents,
int channels = 0, int channels = 0,
int ratio = 1, int ratio = 1,
@@ -52,14 +182,34 @@ namespace Tensorflow
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);


public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
bool uniform_noise, string name = null)
=> image_ops_impl.extract_glimpse(input, size, offsets, centered, normalized, uniform_noise, name);

public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(Tensor boxes, Tensor scores, Tensor max_output_size_per_class,
Tensor max_total_size, float iou_threshold = 0.5f, float score_threshold = -1f / 0f, bool pad_per_class = false, bool clip_boxes = true,
string name = null)
=> image_ops_impl.combined_non_max_suppression(boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold,
pad_per_class, clip_boxes, name);

public (Tensor, Tensor) non_max_suppression_padded(Tensor boxes, Tensor scores, Tensor max_output_size,
float iou_threshold = 0.5f,
float score_threshold = -1f / 0f,
bool pad_to_max_output_size = false,
string name = null,
bool sorted_input = false,
bool canonicalized_coordinates = false,
int tile_size = 512)
=> image_ops_impl.non_max_suppression_padded(boxes, scores, max_output_size, iou_threshold, score_threshold, pad_to_max_output_size,
name, sorted_input, canonicalized_coordinates, tile_size);


public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null)
=> gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name);


public Tensor resize_images(Tensor images, Tensor size, ResizeMethod method = ResizeMethod.BILINEAR,
bool align_corners = false, bool preserve_aspect_ratio = false, string name = null)
public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR,
bool preserve_aspect_ratio = false, string name = null)
=> image_ops_impl.resize_images(images, size, method: method, => image_ops_impl.resize_images(images, size, method: method,
align_corners: align_corners, preserve_aspect_ratio: preserve_aspect_ratio, name: name);
preserve_aspect_ratio: preserve_aspect_ratio, name: name);


public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null)
=> gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name);
@@ -91,6 +241,9 @@ namespace Tensorflow
string name = null, bool half_pixel_centers = false) string name = null, bool half_pixel_centers = false)
=> image_ops_impl.resize_nearest_neighbor(images, size, align_corners: align_corners, => image_ops_impl.resize_nearest_neighbor(images, size, align_corners: align_corners,
name: name, half_pixel_centers: half_pixel_centers); name: name, half_pixel_centers: half_pixel_centers);

public Tensor draw_bounding_boxes(Tensor images, Tensor boxes, Tensor colors = null, string name = null)
=> image_ops_impl.draw_bounding_boxes(images, boxes, colors, name);
} }
} }
} }

+ 5
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -534,6 +534,11 @@ namespace Tensorflow
return gen_array_ops.size(input, name: name, out_type: out_type); return gen_array_ops.size(input, name: name, out_type: out_type);
}); });
} }
public static Tensor tile(Tensor input, Tensor multiples, string name = null)
{
throw new NotImplementedException("tile");
}


public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{ {


+ 17
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -1083,6 +1083,23 @@ namespace Tensorflow


return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Pow", name,
null,
x, y);

return results[0];
}

var _op = tf._op_def_lib._apply_op_helper("Pow", name, args: new { x, y });

return _op.outputs[0];
}


public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims = false, string name = null) public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims = false, string name = null)
{ {


+ 2034
- 14
src/TensorFlowNET.Core/Operations/image_ops_impl.cs
File diff suppressed because it is too large
View File


+ 102
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -17,6 +17,7 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Framework; using Tensorflow.Framework;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -67,6 +68,15 @@ namespace Tensorflow


return gen_math_ops.add_n(inputs, name: name); return gen_math_ops.add_n(inputs, name: name);
} }
public static Tensor round(Tensor x, string name = null)
{
x = ops.convert_to_tensor(x, name: "x");
if (x.dtype.is_integer())
return x;
else
return gen_math_ops.round(x, name: name);
}


public static Tensor cast(IVariableV1 x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) public static Tensor cast(IVariableV1 x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{ {
@@ -118,6 +128,24 @@ namespace Tensorflow
return x; return x;
}); });
} }
public static Tensor saturate_cast(Tensor value, TF_DataType dtype, string name = null)
{
return tf_with(ops.name_scope(name, "saturate_cast", new [] {value}), name =>
{
value = ops.convert_to_tensor(value, name: "value");
// dtype = dtypes.as_dtype(dtype).as_base_dtype();
if (value.dtype.min() < dtype.min())
value = gen_math_ops.maximum(
value,
ops.convert_to_tensor(dtype.min(), dtype: value.dtype, name: "min"));
if (value.dtype.max() > dtype.max())
value = gen_math_ops.minimum(
value,
ops.convert_to_tensor(dtype.max(), dtype: value.dtype, name: "max"));
return cast(value, dtype, name: name);
});
}


public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{ {
@@ -351,6 +379,19 @@ namespace Tensorflow
return reduce_mean(squared_deviations, axis: axis, keepdims: keepdims); return reduce_mean(squared_deviations, axis: axis, keepdims: keepdims);
}); });
} }
public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
if (name == null)
name = "reduce_std";
// else {name = name;}

return tf_with(ops.name_scope(name, "reduce_std", new [] {input_tensor}), scope =>
{
var variance = reduce_variance(input_tensor, axis: axis, keepdims: keepdims);
return gen_math_ops.sqrt(variance);
});
}


public static Tensor sigmoid<T>(T x, string name = null) public static Tensor sigmoid<T>(T x, string name = null)
=> tf_with(ops.name_scope(name, "Sigmoid", x), scope => => tf_with(ops.name_scope(name, "Sigmoid", x), scope =>
@@ -812,6 +853,67 @@ namespace Tensorflow


public static Tensor tanh(Tensor x, string name = null) public static Tensor tanh(Tensor x, string name = null)
=> gen_math_ops.tanh(x, name); => gen_math_ops.tanh(x, name);
public static Tensor tensordot(Tensor x, Tensor y, int[] axes, string name = null)
{
Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
{
if (a.TensorShape.is_fully_defined() && isinstance(axes, (typeof(List<object>), typeof(Tuple))))
{
var shape_a = a.TensorShape.as_list();
// axes
int iter = 0;
foreach (int i in axes)
{
if (i >= 0)
axes[0 + iter] = i;
else
axes[0 + iter] = i + len(shape_a);
iter++;
}
// free
int[] free = {};
iter = 0;
foreach (int i in Enumerable.Range(0, len(axes)))
if (!Array.Exists(axes, i => i == i))
free[free.Length] = i;

// free_dims
int[] free_dims = {};
foreach (int i in free)
free_dims[free_dims.Length] = shape_a[i];

int prod_free = (int)np.prod(free_dims);
// prod_axes
int[] prod_axes_pre = {};
foreach (int i in axes)
prod_axes_pre[prod_axes_pre.Length] = shape_a[i];
int prod_axes = (int)np.prod(prod_axes_pre);
// perm
Tensor perm;
if (flipped)
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free);
else
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free)
+ ops.convert_to_tensor(list(axes));

// new_shape
TensorShape new_shape;
if (flipped)
new_shape = new TensorShape(new int[] {prod_axes, prod_free});
else
new_shape = new TensorShape(new int[] {prod_free, prod_axes});
}

throw new NotImplementedException("_tensordot_reshape");
}

throw new NotImplementedException("tensordot");
}


public static Tensor truediv(Tensor x, Tensor y, string name = null) public static Tensor truediv(Tensor x, Tensor y, string name = null)
=> _truediv_python3(x, y, name); => _truediv_python3(x, y, name);


+ 5
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -228,6 +228,11 @@ namespace Tensorflow
{ {
return (int)type < 100 ? (TF_DataType)((int)type + 100) : type; return (int)type < 100 ? (TF_DataType)((int)type + 100) : type;
} }
public static long min(this TF_DataType type)
{
throw new NotImplementedException($"min {type.name()}");
}


public static long max(this TF_DataType type) public static long max(this TF_DataType type)
{ {


Loading…
Cancel
Save