@@ -155,7 +155,7 @@ namespace Tensorflow | |||
switch (a) | |||
{ | |||
case Tensor tensor: | |||
return tensor.shape[0]; | |||
return (int)tensor.shape[0]; | |||
case Tensors arr: | |||
return arr.Length; | |||
case Array arr: | |||
@@ -10,7 +10,7 @@ namespace Tensorflow.Framework | |||
{ | |||
public static void assert_is_compatible_with(this Tensor self, Tensor other) | |||
{ | |||
if (!self.is_compatible_with(other)) | |||
/*if (!self.is_compatible_with(other)) | |||
{ | |||
var selfDim = self.shape | |||
.Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString()) | |||
@@ -21,7 +21,7 @@ namespace Tensorflow.Framework | |||
.Replace(", }", "}"); | |||
throw new ArgumentException($"Dimensions {selfDim} and {otherDim} are not compatible"); | |||
} | |||
}*/ | |||
} | |||
public static bool is_compatible_with(this Tensor self, Tensor other) | |||
@@ -27,10 +27,10 @@ namespace Tensorflow.Gradients | |||
{ | |||
var grad = grads[0]; | |||
var image = op.inputs[0]; | |||
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | |||
var shape = new TensorShape(image.shape.dims.Skip(1).Take(2).ToArray()); | |||
Tensor image_shape = null; | |||
if (shape.is_fully_defined()) | |||
image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray()); | |||
image_shape = constant_op.constant(image.shape.dims.Skip(1).Take(2).ToArray()); | |||
else | |||
image_shape = array_ops.shape(image)["1:3"]; | |||
@@ -195,7 +195,7 @@ namespace Tensorflow.Gradients | |||
if (op is EagerOperation op_eager && | |||
op_eager.SkipInputIndices.Contains(1) && | |||
y.NDims == 0) | |||
y.ndim == 0) | |||
{ | |||
return new Tensor[] | |||
{ | |||
@@ -759,7 +759,7 @@ namespace Tensorflow.Gradients | |||
if (op is EagerOperation op_eager && | |||
op_eager.SkipInputIndices.Contains(1) && | |||
y.NDims == 0) | |||
y.ndim == 0) | |||
{ | |||
x = math_ops.conj(x); | |||
y = math_ops.conj(y); | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -22,6 +23,12 @@ namespace Tensorflow | |||
public static implicit operator Axis(int[] axis) | |||
=> new Axis(axis); | |||
public static implicit operator Axis(long[] shape) | |||
=> new Axis(shape.Select(x => (int)x).ToArray()); | |||
public static implicit operator Axis(Shape shape) | |||
=> new Axis(shape.dims.Select(x => (int)x).ToArray()); | |||
} | |||
} | |||
@@ -11,8 +11,9 @@ namespace Tensorflow.NumPy | |||
Tensor _tensor; | |||
public TF_DataType dtype => _tensor.dtype; | |||
public ulong size => _tensor.size; | |||
public ulong dtypesize => _tensor.itemsize; | |||
public int ndim => _tensor.NDims; | |||
public ulong dtypesize => _tensor.dtypesize; | |||
public ulong bytesize => _tensor.bytesize; | |||
public int ndim => _tensor.ndim; | |||
public long[] dims => _tensor.dims.Select(x => Convert.ToInt64(x)).ToArray(); | |||
public Shape shape => _tensor.shape; | |||
public IntPtr data => _tensor.TensorDataPointer; | |||
@@ -48,6 +48,12 @@ namespace Tensorflow | |||
public static implicit operator Shape((long, long, long, long) dims) | |||
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | |||
public static implicit operator int[](Shape shape) | |||
=> shape.dims.Select(x => (int)x).ToArray(); | |||
public static implicit operator long[](Shape shape) | |||
=> shape.dims; | |||
public bool IsEmpty => size == 0; | |||
public bool IsScalar => ndim == 0; | |||
@@ -55,6 +61,8 @@ namespace Tensorflow | |||
public static Shape Scalar | |||
=> new Shape(new long[0]); | |||
public long this[int n] => dims[n]; | |||
/// <summary> | |||
/// Returns the size this shape represents. | |||
/// </summary> | |||
@@ -81,6 +89,25 @@ namespace Tensorflow | |||
} | |||
} | |||
public bool is_fully_defined() | |||
{ | |||
return ndim > -1 && dims != null && dims.Count(x => x < 1) == 0; | |||
} | |||
public bool is_compatible_with(TensorShape shape2) | |||
{ | |||
if (dims != null && shape2.dims != null) | |||
{ | |||
if (dims.Contains(-1) || shape2.dims.Contains(-1)) | |||
return true; | |||
if (size != (ulong)shape2.size) | |||
return false; | |||
} | |||
return true; | |||
} | |||
public override bool Equals(object obj) | |||
{ | |||
if(obj is Shape shape) | |||
@@ -92,7 +92,7 @@ namespace Tensorflow | |||
public Tensor _batch_shape() | |||
{ | |||
return array_ops.broadcast_static_shape(new Tensor(_loc.shape), new Tensor(_scale.shape)); | |||
return array_ops.broadcast_static_shape(new Tensor(_loc.shape.dims), new Tensor(_scale.shape.dims)); | |||
} | |||
protected override Tensor _log_prob(Tensor x) | |||
@@ -27,9 +27,9 @@ namespace Tensorflow.Operations | |||
{ | |||
var p = prefix; | |||
var p_static = tensor_util.constant_value(prefix); | |||
if (p.NDims == 0) | |||
if (p.ndim == 0) | |||
p = array_ops.expand_dims(p, 0); | |||
else if (p.NDims != 1) | |||
else if (p.ndim != 1) | |||
throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}"); | |||
var s_tensor_shape = new TensorShape(suffix); | |||
@@ -186,7 +186,7 @@ namespace Tensorflow | |||
private static Tensor _constant_if_small(int value, Tensor shape) | |||
{ | |||
return shape < 1000L; | |||
return shape < 1000UL; | |||
} | |||
private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name) | |||
@@ -330,7 +330,7 @@ namespace Tensorflow | |||
{ | |||
name = scope; | |||
var input_tensor = ops.convert_to_tensor(inputs); | |||
return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name); | |||
return constant_op.constant(input_tensor.ndim, dtype: tf.int32, name: name); | |||
}); | |||
} | |||
@@ -340,7 +340,7 @@ namespace Tensorflow | |||
{ | |||
name = scope; | |||
var input_tensor = ops.convert_to_tensor(input); | |||
var input_shape = tensor_util.to_shape(input_tensor.shape); | |||
var input_shape = input_tensor.shape; | |||
if (optimize && input_shape.ndim > 0) | |||
return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name); | |||
else | |||
@@ -364,7 +364,7 @@ namespace Tensorflow | |||
tensor = ops.convert_to_tensor(tensor, name: "tensor"); | |||
// is_fully_defined return unexpected value. | |||
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) | |||
if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT) | |||
{ | |||
} | |||
@@ -589,9 +589,9 @@ namespace Tensorflow | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
var input_shape = input.TensorShape; | |||
if (optimize && input.NDims > -1 && input_shape.is_fully_defined()) | |||
if (optimize && input.ndim > -1 && input_shape.is_fully_defined()) | |||
{ | |||
var nd = np.array(input.shape).astype(out_type.as_system_dtype()); | |||
var nd = np.array(input.shape.dims).astype(out_type.as_system_dtype()); | |||
return constant_op.constant(nd, name: name); | |||
} | |||
} | |||
@@ -607,7 +607,7 @@ namespace Tensorflow | |||
name = scope; | |||
var input_tensor = ops.convert_to_tensor(input); | |||
var input_shape = tensor_util.to_shape(input_tensor.shape); | |||
var input_shape = input_tensor.shape; | |||
if (optimize) | |||
{ | |||
if (input_shape.is_fully_defined()) | |||
@@ -633,7 +633,7 @@ namespace Tensorflow | |||
tensor = ops.convert_to_tensor(tensor, name: "tensor"); | |||
// is_fully_defined return unexpected value. | |||
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) | |||
if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT) | |||
{ | |||
} | |||
@@ -933,7 +933,7 @@ namespace Tensorflow | |||
string name = "split") | |||
{ | |||
if (num == -1) | |||
num = size_splits.shape[0]; | |||
num = (int)size_splits.shape[0]; | |||
return gen_array_ops.split_v(value, size_splits, axis, num, name: name); | |||
} | |||
@@ -91,7 +91,7 @@ namespace Tensorflow | |||
elem.dtype, | |||
size: tf.constant(n), | |||
dynamic_size: false, | |||
element_shape: elem.shape.Skip(1).ToArray(), | |||
element_shape: elem.shape.dims.Skip(1).ToArray(), | |||
infer_shape: true)).ToList(); | |||
for (int index = 0; index < elems_ta.Count; index++) | |||
@@ -341,14 +341,14 @@ or rank = 4. Had rank = {0}", rank)); | |||
{ | |||
h = _get_dim(image, 0); // img_h == h[0], dynamic_h == h[1] | |||
w = _get_dim(image, 1); | |||
d = image.shape[3]; | |||
d = (int)image.shape[3]; | |||
} | |||
else | |||
{ | |||
bs = image.shape[0]; | |||
bs = (int)image.shape[0]; | |||
h = _get_dim(image, 1); | |||
w = _get_dim(image, 2); | |||
d = image.shape[3]; | |||
d = (int)image.shape[3]; | |||
} | |||
object hd, bbox_h_start; | |||
@@ -1115,7 +1115,7 @@ new_height, new_width"); | |||
array_ops.expand_dims(tf.constant(3), 0)); | |||
var multiples = array_ops.concat(new Tensor[] { shape_list }, 0); | |||
var rgb = array_ops.tile(images, multiples, name: name); | |||
int[] rgb_temp = images.shape.Take(images.shape.Length - 1).ToArray(); | |||
int[] rgb_temp = images.shape.dims.Take(images.shape.ndim - 1).Select(x => (int)x).ToArray(); | |||
rgb.set_shape(array_ops.concat(new Tensor[] { ops.convert_to_tensor(rgb_temp) }, 3)); | |||
return rgb; | |||
}); | |||
@@ -1459,7 +1459,7 @@ new_height, new_width"); | |||
// shape takes an int, python code passes size, a Tensor. NDims is the only int type | |||
// i could think of a Tensor having. it might be incorrect tho, so keep that in mind. | |||
return array_ops.reshape(g, shape: new int[] { size.NDims, size.NDims, 1, 1 }); | |||
return array_ops.reshape(g, shape: new int[] { size.ndim, size.ndim, 1, 1 }); | |||
} | |||
internal static (Tensor, Tensor) _ssim_per_channel(Tensor img1, Tensor img2, float max_val = 1f, | |||
@@ -1487,7 +1487,7 @@ new_height, new_width"); | |||
img1 = array_ops.identity(img1); | |||
var kernel = _fspecial_gauss(filter_size_tensor, filter_sigma_tensor); | |||
kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 })); | |||
kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, (int)shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 })); | |||
float compensation = 1.0f; | |||
@@ -1503,8 +1503,8 @@ new_height, new_width"); | |||
(Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2); | |||
var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32); | |||
var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims)); | |||
cs = math_ops.reduce_mean(cs, new(axes.dims)); | |||
var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims); | |||
cs = math_ops.reduce_mean(cs, axes.dims); | |||
return (ssim_val, cs); | |||
} | |||
@@ -1685,7 +1685,7 @@ new_height, new_width"); | |||
var kernels_tf = constant_op.constant(kernels, dtype: image.dtype); | |||
kernels_tf = array_ops.tile( | |||
kernels_tf, new Tensor(new int[] { 1, 1, image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters"); | |||
kernels_tf, new Tensor(new int[] { 1, 1, (int)image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters"); | |||
var pad_sizes = new int[,] { { 0, 0 }, { 1, 1 }, { 1, 1 }, { 0, 0 } }; | |||
var padded = array_ops.pad(image, new Tensor(pad_sizes), mode: "reflect"); | |||
@@ -1966,8 +1966,8 @@ new_height, new_width"); | |||
Tensor index_offsets, indices, sorted_scores, sorted_boxes, sorted_scores_indices; | |||
using (ops.name_scope("sort_scores_and_boxes")) | |||
{ | |||
batch_size = array_ops.shape(boxes).dims[0]; | |||
num_boxes = array_ops.shape(boxes).dims[1]; | |||
batch_size = (int)array_ops.shape(boxes).dims[0]; | |||
num_boxes = (int)array_ops.shape(boxes).dims[1]; | |||
sorted_scores_indices = null; /*sort_ops.argsort( | |||
scores, axis: 1, direction: "DESCENDING); */ | |||
index_offsets = math_ops.range(batch_size) * num_boxes; | |||
@@ -178,7 +178,7 @@ namespace Tensorflow | |||
logits = ops.convert_to_tensor(logits); | |||
var shape = logits.shape; | |||
bool is_last_dim = dim == -1 || dim == shape.Length - 1; | |||
bool is_last_dim = dim == -1 || dim == shape.ndim - 1; | |||
if (is_last_dim) | |||
return compute_op(logits, name); | |||
@@ -37,7 +37,7 @@ namespace Tensorflow | |||
{ | |||
get | |||
{ | |||
return _row_splits.shape[0] - 1; | |||
return (int)_row_splits.shape[0] - 1; | |||
} | |||
} | |||
@@ -145,6 +145,10 @@ namespace Tensorflow | |||
byte[,] val => InitTensor(val, shape, dtype), | |||
byte[,,] val => InitTensor(val, shape, dtype), | |||
byte[,,,] val => InitTensor(val, shape, dtype), | |||
short[] val => InitTensor(val, shape, dtype), | |||
short[,] val => InitTensor(val, shape, dtype), | |||
short[,,] val => InitTensor(val, shape, dtype), | |||
short[,,,] val => InitTensor(val, shape, dtype), | |||
int[] val => InitTensor(val, shape, dtype), | |||
int[,] val => InitTensor(val, shape, dtype), | |||
int[,,] val => InitTensor(val, shape, dtype), | |||
@@ -153,6 +157,10 @@ namespace Tensorflow | |||
long[,] val => InitTensor(val, shape, dtype), | |||
long[,,] val => InitTensor(val, shape, dtype), | |||
long[,,,] val => InitTensor(val, shape, dtype), | |||
ulong[] val => InitTensor(val, shape, dtype), | |||
ulong[,] val => InitTensor(val, shape, dtype), | |||
ulong[,,] val => InitTensor(val, shape, dtype), | |||
ulong[,,,] val => InitTensor(val, shape, dtype), | |||
float[] val => InitTensor(val, shape, dtype), | |||
float[,] val => InitTensor(val, shape, dtype), | |||
float[,,] val => InitTensor(val, shape, dtype), | |||
@@ -18,7 +18,7 @@ namespace Tensorflow | |||
if (typeof(T).as_tf_dtype() != dtype) | |||
throw new ArrayTypeMismatchException($"dtype {dtype} mismatch."); | |||
if (NDims == 0 && size == 1) //is it a scalar? | |||
if (ndim == 0 && size == 1) //is it a scalar? | |||
{ | |||
unsafe | |||
{ | |||
@@ -28,7 +28,7 @@ namespace Tensorflow | |||
//types match, no need to perform cast | |||
var ret = new T[size]; | |||
var len = (long)(size * itemsize); | |||
var len = (long)(size * dtypesize); | |||
var src = (T*)buffer; | |||
fixed (T* dst = ret) | |||
@@ -72,17 +72,17 @@ namespace Tensorflow | |||
/// </summary> | |||
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | |||
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | |||
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | |||
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | |||
public ulong dtypesize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | |||
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dtypesize; | |||
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | |||
public int NDims => rank; | |||
public int ndim => rank; | |||
/// <summary> | |||
/// The name of the device on which this tensor will be produced, or null. | |||
/// </summary> | |||
public virtual string Device => op.Device; | |||
public int[] dims => shape; | |||
public long[] dims => shape.dims; | |||
/// <summary> | |||
/// Used for keep other pointer when do implicit operating | |||
@@ -107,7 +107,7 @@ namespace Tensorflow | |||
/// Returns the shape of a tensor. | |||
/// </summary> | |||
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks> | |||
public int[] shape | |||
public Shape shape | |||
{ | |||
get | |||
{ | |||
@@ -123,7 +123,7 @@ namespace Tensorflow | |||
dims[i] = c_api.TF_Dim(_handle, i); | |||
} | |||
return dims.Select(x => ((IConvertible)x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); | |||
return dims; | |||
} | |||
set | |||
@@ -131,7 +131,7 @@ namespace Tensorflow | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||
else | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.Status.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
} | |||
@@ -139,10 +139,10 @@ namespace Tensorflow | |||
public int[] _shape_tuple() | |||
{ | |||
return rank < 0 ? null : shape; | |||
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | |||
} | |||
public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | |||
public TensorShape TensorShape => rank < 0 ? new TensorShape() : shape; | |||
/// <summary> | |||
/// Keras History: (Layer, (node_index, tensor_index)) | |||
@@ -109,7 +109,8 @@ namespace Tensorflow | |||
var length = shape.size * (ulong)dtype.get_datatype_size(); | |||
var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length); | |||
var tensor = TF_TensorData(handle); | |||
System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length); | |||
if (tensor != IntPtr.Zero) | |||
System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length); | |||
return handle; | |||
} | |||
@@ -124,6 +124,8 @@ namespace Tensorflow | |||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||
case long val: | |||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||
case ulong val: | |||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||
case float val: | |||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||
case double val: | |||
@@ -146,7 +148,7 @@ namespace Tensorflow | |||
if (shape == null) | |||
return t; | |||
if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) | |||
if (t.shape.dims.SequenceEqual(shape.dims)) | |||
return t; | |||
if (verify_shape) | |||
@@ -127,7 +127,7 @@ namespace Tensorflow | |||
} | |||
else if (values is Tensor tensor && tensor.IsReferencedByNDArray) | |||
{ | |||
var len = tensor.itemsize * tensor.size; | |||
var len = tensor.dtypesize * tensor.size; | |||
byte[] bytes = tensor.BufferToArray(); | |||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | |||
} | |||
@@ -45,7 +45,7 @@ namespace Tensorflow | |||
var restored_tensor = restored_tensors[0]; | |||
return gen_state_ops.assign(op, | |||
restored_tensor, | |||
validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined()); | |||
validate_shape: restored_shapes == null && op.shape.is_fully_defined()); | |||
} | |||
} | |||
} |
@@ -50,7 +50,7 @@ namespace Tensorflow | |||
public Operation Op => _variable.op; | |||
public TF_DataType dtype => _variable.dtype; | |||
public TensorShape shape => tensor_util.to_shape(_variable.shape); | |||
public TensorShape shape => _variable.shape; | |||
public string Device => ""; | |||
public string Name => _variable.name; | |||
@@ -297,8 +297,8 @@ namespace Tensorflow.Keras | |||
// x = permute_dimensions(x, [0, 3, 1, 2]); | |||
throw new NotImplementedException(""); | |||
int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor; | |||
int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor; | |||
int new_height = original_shape[rows] < 0 ? -1 : (int)original_shape[rows] * height_factor; | |||
int new_width = original_shape[cols] < 0 ? -1 : (int)original_shape[cols] * width_factor; | |||
TensorShape output_shape = data_format == "channels_first" ? | |||
(-1, -1, new_height, new_width) : (-1, new_height, new_width, -1); | |||
@@ -316,7 +316,7 @@ namespace Tensorflow.Keras | |||
{ | |||
if(axis < 0) | |||
{ | |||
var rank = tensors[0].NDims; | |||
var rank = tensors[0].ndim; | |||
if (rank > -1) | |||
axis += rank; | |||
else | |||
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
{ | |||
this.args = args; | |||
_process_tensorlike(); | |||
num_samples = args.X.shape[0]; | |||
num_samples = (int)args.X.shape[0]; | |||
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | |||
_batch_size = batch_size; | |||
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); | |||
@@ -63,8 +63,8 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
var y_t_rank = y_t.rank; | |||
var y_p_rank = y_p.rank; | |||
var y_t_last_dim = y_t.shape[y_t.shape.Length - 1]; | |||
var y_p_last_dim = y_p.shape[y_p.shape.Length - 1]; | |||
var y_t_last_dim = y_t.shape[y_t.shape.ndim - 1]; | |||
var y_p_last_dim = y_p.shape[y_p.shape.ndim - 1]; | |||
bool is_binary = y_p_last_dim == 1; | |||
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; | |||
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Metrics | |||
var y_true_rank = y_true.TensorShape.ndim; | |||
// If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) | |||
if (y_true_rank != -1 && y_pred_rank != -1 | |||
&& y_true.shape.Length == y_pred.shape.Length) | |||
&& y_true.shape.ndim == y_pred.shape.ndim) | |||
y_true = array_ops.squeeze(y_true, axis: new[] { -1 }); | |||
y_pred = math_ops.argmax(y_pred, -1); | |||
@@ -212,13 +212,13 @@ namespace Tensorflow.Keras | |||
string data_format = "channels_last") | |||
{ | |||
var input_shape = inputs.shape; | |||
if (inputs.shape.Length == 0) | |||
if (inputs.shape.ndim == 0) | |||
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); | |||
var premutation = new List<int>() { 0 }; | |||
if (data_format == "channels_first" && inputs.NDims > 1) | |||
if (data_format == "channels_first" && inputs.ndim > 1) | |||
{ | |||
premutation.AddRange(Binding.range(2, inputs.NDims)); | |||
premutation.AddRange(Binding.range(2, inputs.ndim)); | |||
premutation.Add(1); | |||
inputs = array_ops.transpose(inputs, premutation.ToArray()); | |||
} | |||
@@ -40,7 +40,7 @@ namespace Tensorflow.Native.UnitTest.Sessions | |||
csession.Run(s); | |||
Tensor outTensor = csession.output_tensor(0); | |||
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
EXPECT_EQ(0, outTensor.NDims); | |||
EXPECT_EQ(0, outTensor.ndim); | |||
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
var output_contents = outTensor.ToArray<int>(); | |||
EXPECT_EQ(3 + 2, output_contents[0]); | |||
@@ -61,7 +61,7 @@ namespace Tensorflow.Native.UnitTest.Sessions | |||
outTensor = csession.output_tensor(0); | |||
ASSERT_TRUE(outTensor != IntPtr.Zero); | |||
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
EXPECT_EQ(0, outTensor.NDims); // scalar | |||
EXPECT_EQ(0, outTensor.ndim); // scalar | |||
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
output_contents = outTensor.ToArray<int>(); | |||
EXPECT_EQ(-(7 + 2), output_contents[0]); | |||
@@ -66,7 +66,7 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
long[] dims = { 2, 3 }; | |||
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||
EXPECT_EQ(2, t.NDims); | |||
EXPECT_EQ(2, t.ndim); | |||
EXPECT_EQ((int)dims[0], t.shape[0]); | |||
EXPECT_EQ(num_bytes, t.bytesize); | |||
t.Dispose(); | |||
@@ -126,7 +126,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
var x = tf.constant(new[,] { { 1, 2 } }); | |||
var neg_x = tf.negative(x); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 2 }, neg_x.shape.dims)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>())); | |||
} | |||
@@ -145,7 +145,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var tensor = tf.constant(nd); | |||
var data = tensor.numpy().ToArray<int>(); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, tensor.shape.dims)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); | |||
} | |||
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var b = tf.Variable(-0.73f, name: "bias"); | |||
using var g = tf.GradientTape(); | |||
var pred = W * X + b; | |||
var test = tf.slice(pred, new[] { 0 }, pred.shape); | |||
var test = tf.slice(pred, new[] { 0 }, (int[])pred.shape); | |||
var gradients = g.gradient(test, (W, b)); | |||
Assert.AreEqual((float)gradients.Item1, 0f); | |||
Assert.AreEqual((float)gradients.Item2, 10f); | |||
@@ -85,14 +85,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
{ { 1 }, { 2 }, { 3 } }, | |||
{ { 4 }, { 5 }, { 6 } } | |||
})); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, a.shape.dims)); | |||
var b = tf.constant(new[, ,] | |||
{ | |||
{ { 1 }, { 2 }, { 3 } }, | |||
{ { 4 }, { 5 }, { 6 } } | |||
}); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, b.shape.dims)); | |||
} | |||
[TestMethod] | |||
@@ -103,7 +103,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); | |||
var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims)); | |||
} | |||
[TestMethod] | |||
@@ -114,7 +114,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); | |||
var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims)); | |||
} | |||
[TestMethod] | |||
@@ -128,7 +128,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var splitValue = tf.split(value, 3, axis: 0); | |||
Assert.AreEqual(3, splitValue.Length); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 2 }, splitValue[0].shape.dims)); | |||
} | |||
#region ones/zeros like | |||