Browse Source

fix Shape.Equals

tags/TimeSeries
dss Haiping 3 years ago
parent
commit
67a7fc57e0
2 changed files with 9 additions and 3 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  2. +5
    -3
      src/TensorFlowNET.Keras/BackendImpl.cs

+ 4
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -100,6 +100,10 @@ namespace Tensorflow.NumPy
if (shape.ndim != shape2.Length) if (shape.ndim != shape2.Length)
return false; return false;
return Enumerable.SequenceEqual(shape.dims, shape2); return Enumerable.SequenceEqual(shape.dims, shape2);
case int[] shape3:
if (shape.ndim != shape3.Length)
return false;
return Enumerable.SequenceEqual(shape.as_int_list(), shape3);
default: default:
return false; return false;
} }


+ 5
- 3
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -347,19 +347,21 @@ namespace Tensorflow.Keras
string data_format = null, string data_format = null,
Shape dilation_rate = null) Shape dilation_rate = null)
{ {
/*
var force_transpose = false; var force_transpose = false;
if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 })) if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 }))
force_transpose = true; force_transpose = true;
// x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
*/
var tf_data_format = "NHWC"; var tf_data_format = "NHWC";
padding = padding.ToUpper(); padding = padding.ToUpper();
strides = new Shape(1, strides[0], strides[1], 1); strides = new Shape(1, strides[0], strides[1], 1);
if (dilation_rate.Equals(new long[] { 1, 1 }))
if (dilation_rate.Equals(new[] { 1, 1 }))
x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides, x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides,
padding: padding, padding: padding,
data_format: tf_data_format); data_format: tf_data_format);
else else
throw new NotImplementedException("");
throw new NotImplementedException("dilation_rate other than [1,1] is not yet supported");


return x; return x;
} }


Loading…
Cancel
Save