diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index bfc203a6..9fbf3924 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -71,7 +71,7 @@ namespace Tensorflow string name = null) { if (dtype.is_integer()) - return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, dtype, seed, name); + return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, seed, name); else return random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); } diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs index bed327a3..4a3a309e 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs @@ -23,8 +23,18 @@ namespace Tensorflow.NumPy public NDArray rand(params int[] shape) => throw new NotImplementedException(""); + [AutoNumPy] public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32) - => throw new NotImplementedException(""); + { + if(high == null) + { + high = low; + low = 0; + } + size = size ?? Shape.Scalar; + var tensor = random_ops.random_uniform_int(shape: size, minval: low, maxval: (int)high); + return new NDArray(tensor); + } public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null) => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Operations/random_ops.cs b/src/TensorFlowNET.Core/Operations/random_ops.cs index 9f823c96..dddcc05a 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.cs @@ -94,7 +94,6 @@ namespace Tensorflow public static Tensor random_uniform_int(int[] shape, int minval = 0, int maxval = 1, - TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null) { @@ -103,8 +102,8 @@ namespace Tensorflow name = scope; var (seed1, seed2) = random_seed.get_seed(seed); var tensorShape = tensor_util.shape_tensor(shape); - var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); - var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); + var minTensor = ops.convert_to_tensor(minval, name: "min"); + var maxTensor = ops.convert_to_tensor(maxval, name: "max"); return gen_random_ops.random_uniform_int(tensorShape, minTensor, maxTensor, seed: seed1, seed2: seed2); }); }