diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 04b994ed..d5490f28 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -43,6 +43,7 @@ namespace Tensorflow.Contexts public SafeContextHandle Handle => _handle; int? _seed; + Random _rng; public Context() { @@ -74,11 +75,23 @@ namespace Tensorflow.Contexts } public void set_global_seed(int? seed) - => _seed = seed; + { + _seed = seed; + if (seed.HasValue) + _rng = new Random(seed.Value); + else + _rng = null; + // Also clear the kernel cache, to reset any existing seeds + if (_handle != null) + c_api.TFE_ContextClearCaches(_handle); + } public int? global_seed() => _seed; + public int? internal_operation_seed() + => _rng?.Next(0, int.MaxValue); + public void start_step() => c_api.TFE_ContextStartStep(_handle); @@ -94,7 +107,7 @@ namespace Tensorflow.Contexts { if(context_switches.Count() == 0) tf.enable_eager_execution(); - + return context_switches.Current().EagerMode; } diff --git a/src/TensorFlowNET.Core/Framework/random_seed.py.cs b/src/TensorFlowNET.Core/Framework/random_seed.py.cs index e51b7b81..8732c030 100644 --- a/src/TensorFlowNET.Core/Framework/random_seed.py.cs +++ b/src/TensorFlowNET.Core/Framework/random_seed.py.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; using static Tensorflow.Binding; namespace Tensorflow @@ -21,6 +22,7 @@ namespace Tensorflow public class random_seed { private static int DEFAULT_GRAPH_SEED = 87654321; + private static Dictionary _graph_to_seed_dict = new Dictionary(); public static (int?, int?) get_seed(int? op_seed = null) { @@ -32,7 +34,20 @@ namespace Tensorflow global_seed = ops.get_default_graph().seed; if (global_seed.HasValue) + { + if (!op_seed.HasValue) + if (tf.executing_eagerly()) + op_seed = tf.Context.internal_operation_seed(); + else + { + if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) + seed = 0; + _graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; + op_seed = seed; + } + return (global_seed, op_seed); + } if (op_seed.HasValue) return (DEFAULT_GRAPH_SEED, op_seed); diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index 8528f4c4..6fddc47b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -131,7 +131,9 @@ namespace Tensorflow var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "RandomShuffle", name, null, - value, seed, seed2); + value, + "seed", seed, + "seed2", seed2); return results[0]; } diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs index ed70fa35..ded1d58c 100644 --- a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics /// Test the function of setting random seed /// This will help regenerate the same result /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomSeedTest() { var initValue = np.arange(6).reshape(3, 2); @@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest.Basics /// /// compare to Test above, seed is also added in params /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomSeedTest2() { var initValue = np.arange(6).reshape(3, 2); @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics /// /// This part we use funcs in tf.random rather than only tf /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomRaodomSeedTest() { tf.set_random_seed(1234); @@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest.Basics /// /// compare to Test above, seed is also added in params /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomRaodomSeedTest2() { tf.set_random_seed(1234);