Browse Source

fix Shape.Equals

tags/TimeSeries
dss Haiping 3 years ago
parent
commit
67a7fc57e0
2 changed files with 9 additions and 3 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  2. +5
    -3
      src/TensorFlowNET.Keras/BackendImpl.cs

+ 4
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -100,6 +100,10 @@ namespace Tensorflow.NumPy
if (shape.ndim != shape2.Length)
return false;
return Enumerable.SequenceEqual(shape.dims, shape2);
case int[] shape3:
if (shape.ndim != shape3.Length)
return false;
return Enumerable.SequenceEqual(shape.as_int_list(), shape3);
default:
return false;
}


+ 5
- 3
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -347,19 +347,21 @@ namespace Tensorflow.Keras
string data_format = null,
Shape dilation_rate = null)
{
/*
var force_transpose = false;
if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 }))
force_transpose = true;
// x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
*/
var tf_data_format = "NHWC";
padding = padding.ToUpper();
strides = new Shape(1, strides[0], strides[1], 1);
if (dilation_rate.Equals(new long[] { 1, 1 }))
if (dilation_rate.Equals(new[] { 1, 1 }))
x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides,
padding: padding,
data_format: tf_data_format);
else
throw new NotImplementedException("");
throw new NotImplementedException("dilation_rate other than [1,1] is not yet supported");

return x;
}


Loading…
Cancel
Save