Browse Source

Merge pull request #1043 from AsakusaRinne/fix_1040

Partially fix the error when crop image.
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
34338c72c2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 17 deletions
  1. +21
    -17
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs

+ 21
- 17
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -542,32 +542,32 @@ or rank = 4. Had rank = {0}", rank));
image_shape));
}

var assert_ops = _CheckAtLeast3DImage(image, require_static: false);
var assert_ops = _CheckAtLeast3DImage(image, require_static: false).ToList();

// batch: [0], height: [1], width: [2], depth: [3]
var bhwd = _ImageDimensions(image, rank: 4);

assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(offset_height),
assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(offset_height),
tf.constant(0)), typeof(ValueError),
"offset_height must be >= 0.");
assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(offset_width),
"offset_height must be >= 0."));
assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(offset_width),
tf.constant(0)), typeof(ValueError),
"offset_width must be >= 0.");
assert_ops[assert_ops.Length] = _assert(check_ops.assert_less(tf.constant(0),
"offset_width must be >= 0."));
assert_ops.Add(_assert(check_ops.assert_less(tf.constant(0),
tf.constant(target_width)), typeof(ValueError),
"target_width must be > 0.");
assert_ops[assert_ops.Length] = _assert(check_ops.assert_less(tf.constant(0),
"target_width must be > 0."));
assert_ops.Add(_assert(check_ops.assert_less(tf.constant(0),
tf.constant(target_height)), typeof(ValueError),
"target_height must be > 0.");
assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(bhwd[2]),
"target_height must be > 0."));
assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(bhwd[2]),
tf.constant(target_width + offset_width)),
typeof(ValueError),
"width must be >= target + offset.");
assert_ops[assert_ops.Length] = _assert(check_ops.assert_greater_equal(tf.constant(bhwd[1]),
"width must be >= target + offset."));
assert_ops.Add(_assert(check_ops.assert_greater_equal(tf.constant(bhwd[1]),
tf.constant(target_height + offset_height)),
typeof(ValueError),
"height must be >= target + offset.");
image = control_flow_ops.with_dependencies(assert_ops, image);
"height must be >= target + offset."));
image = control_flow_ops.with_dependencies(assert_ops.ToArray(), image);

var cropped = array_ops.slice(
image, array_ops.stack(new[] { 0, offset_height, offset_width, 0 }),
@@ -575,12 +575,16 @@ or rank = 4. Had rank = {0}", rank));

Shape cropped_shape_result()
{
long[] i_remnants = { };
long[] i_remnants = new long[4];
int idx = 0;
foreach (var i in new[] { bhwd[0], target_height, target_width, bhwd[3] })
{
if (_is_tensor(i))
return null;
i_remnants[idx] = -1;
else
i_remnants[i_remnants.Length] = i;
i_remnants[idx] = i;
idx++;
}
return new Shape(i_remnants);
};
var cropped_shape = cropped_shape_result();


Loading…
Cancel
Save