diff --git a/README.md b/README.md
index 42cd3ec4..fd92c400 100644
--- a/README.md
+++ b/README.md
@@ -66,7 +66,7 @@ using(var sess = tf.Session())
Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html).
-More examples:
+### More examples:
* [Hello World](test/TensorFlowNET.Examples/HelloWorld.cs)
* [Basic Operations](test/TensorFlowNET.Examples/BasicOperations.cs)
diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index 5b80313c..32b75807 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -7,7 +7,7 @@ namespace Tensorflow
public static partial class tf
{
public static IInitializer zeros_initializer => new Zeros();
- public static IInitializer glorot_uniform => new GlorotUniform();
+ public static IInitializer glorot_uniform_initializer => new GlorotUniform();
public static variable_scope variable_scope(string name_or_scope,
string default_name = null,
diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
index 8cf40fd1..354177f9 100644
--- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
@@ -24,10 +24,34 @@ namespace Tensorflow
if (!seed2.HasValue)
seed2 = 0;
- var _op = _op_def_lib._apply_op_helper("RandomStandardNormal", name: name,
+ var _op = _op_def_lib._apply_op_helper("RandomStandardNormal",
+ name: name,
args: new { shape, dtype, seed, seed2 });
return _op.outputs[0];
}
+
+ ///
+ /// Outputs random values from a uniform distribution.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
+ {
+ if (!seed.HasValue)
+ seed = 0;
+ if (!seed2.HasValue)
+ seed2 = 0;
+
+ var _op = _op_def_lib._apply_op_helper("RandomUniform",
+ name: name,
+ args: new { shape, dtype, seed, seed2});
+
+ return _op.outputs[0];
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
index eae27c58..00f1aa00 100644
--- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
@@ -56,7 +56,11 @@ namespace Tensorflow
return with(new ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
{
name = scope;
- return null;
+ var tensorShape = _ShapeTensor(shape);
+ var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
+ var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
+ var rnd = gen_random_ops.random_uniform(tensorShape, dtype);
+ return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name);
});
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index e2dd55a5..76db062e 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -77,6 +77,11 @@ namespace Tensorflow
return null;
}
+ public TensorShape getShape()
+ {
+ return tensor_util.to_shape(shape);
+ }
+
///
/// number of dimensions
/// 0 Scalar (magnitude only)
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index f0e4e721..ea7069e4 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -1,4 +1,6 @@
-using System;
+using Google.Protobuf;
+using Google.Protobuf.Collections;
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -99,7 +101,7 @@ namespace Tensorflow
if (initial_value is null)
throw new ValueError("initial_value must be specified.");
- var init_from_fn = false;
+ var init_from_fn = initial_value.GetType().Name == "Func`1";
if(collections == null)
{
@@ -115,12 +117,27 @@ namespace Tensorflow
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
ops.init_scope();
- var values = init_from_fn ? new List