Browse Source

some more bug fixes

pull/571/head
carb0n GitHub 5 years ago
parent
commit
1b4d323f9f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 20 deletions
  1. +21
    -20
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs

+ 21
- 20
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -31,13 +31,14 @@ namespace Tensorflow

internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_static)
{
TensorShape image_shape;
try
{
if ( image.get_shape().NDims == null )
if ( image.shape.NDims == null )
{
var image_shape = image.get_shape().with_rank(3);
image_shape = image.shape.with_rank(3);
} else {
var image_shape = image.get_shape().with_rank_at_least(3);
image_shape = image.shape.with_rank_at_least(3);
}
}
catch (ValueError)
@@ -58,21 +59,21 @@ namespace Tensorflow
return new Operation[] {
check_ops.assert_positive(
array_ops.shape(image)[-3..],
@"inner 3 dims of 'image.shape'
must be > 0."),
new {@"inner 3 dims of 'image.shape'
must be > 0."}),
check_ops.assert_greater_equal(
array_ops.rank(image),
3,
ops.convert_to_tensor(3),
message: "'image' must be at least three-dimensional.")
};
} else {
return new Operation[] {};
}
}
internal static Tensor fix_image_flip_shape(Tensor image, Tensor result)
{
TensorShape image_shape = image.get_shape();
TensorShape image_shape = image.shape;
if (image_shape == tensor_shape.unknown_shape())
{
result.set_shape(new { null, null, null });
@@ -94,36 +95,36 @@ namespace Tensorflow
seed: seed,
scope_name: "random_flip_left_right");

internal static Tensor _random_flip(Tensor image, int flipindex, int seed,
internal static Tensor _random_flip(Tensor image, int flip_index, int seed,
string scope_name)
{
using ( var scope = ops.name_scope(null, scope_name, image))
using ( var scope = ops.name_scope(null, scope_name, new { image }) )
{
image = ops.convert_to_tensor(image, name: "image");
image = _AssertAtLeast3DImage(image);
var shape = image.get_shape();
Tensor shape = image.shape;
if ( shape.NDims == 3 || shape.NDims == null )
{
var uniform_random = random_ops.random_uniform(new Tensor [], 0, 1.0, seed: seed);
var uniform_random = random_ops.random_uniform(new {}, 0, 1.0, seed: seed);
var mirror_cond = math_ops.less(uniform_random, .5);
var result = control_flow_ops.cond(
pred: mirror_cond,
true_fn: array_ops.reverse(image, flipindex as int[]),
true_fn: array_ops.reverse(image, new { flip_index }),
false_fn: image,
name: scope
);
return fix_image_flip_shape(image, result);
} else if ( shape.NDims == 4 )
{
var batch_size = array_ops.shape(image)[0];
var uniform_random = random_ops.random_uniform(batch_size,
var batch_size = array_ops.shape(image);
var uniform_random = random_ops.random_uniform(batch_size[0],
0,
1.0,
1.0 as float,
seed: seed);
var flips = math_ops.round(
array_ops.reshape(uniform_random, shape: new Tensor [batch_size, 1, 1, 1]));
array_ops.reshape(uniform_random, shape: new Tensor [batch_size[0], 1, 1, 1]));
flips = math_ops.cast(flips, image.dtype);
var flipped_input = array_ops.reverse(image, flip_index + 1 as int[]);
var flipped_input = array_ops.reverse(image, new { flip_index + 1 });
return flips * flipped_input + (1 - flips) * image;
} else
{
@@ -131,7 +132,7 @@ namespace Tensorflow
}
}
}
public static Tensor flip_left_right(Tensor image)
=> _flip(image, 1, "flip_left_right");

@@ -144,7 +145,7 @@ namespace Tensorflow
{
image = ops.convert_to_tensor(image, name: "image");
image = _AssertAtLeast3DImage(image);
TensorShape shape = image.get_shape();
Tensor shape = image.shape;
if ( shape.NDims == 3 || shape.NDims == null )
{
return fix_image_flip_shape(image, array_ops.reverse(image, new { flip_index }));


Loading…
Cancel
Save