diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 99a01b8b..76e77986 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -89,6 +89,11 @@ namespace Tensorflow return shape; } + public static TensorShape to_shape(long[] dims) + { + return new TensorShape(dims.Select(x => (int)x).ToArray()); + } + public static TensorShape as_shape(this IShape shape, int[] dims) { return new TensorShape(dims); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 075123f1..65c34370 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -18,6 +18,7 @@ namespace Tensorflow public Operation initializer => _initializer_op; public Operation op => _variable.op; public TF_DataType dtype => _variable.dtype; + public TensorShape shape => tensor_util.to_shape(_variable.shape); public string name => _variable.name;