diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index bd74c8fd..d5656e87 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -93,7 +93,12 @@ namespace Tensorflow => random_ops.random_shuffle(value, seed: seed, name: name); public void set_random_seed(int seed) - => ops.get_default_graph().seed = seed; + { + if (executing_eagerly()) + Context.set_global_seed(seed); + else + ops.get_default_graph().seed = seed; + } public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 4d894702..95f75a94 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -42,6 +42,9 @@ namespace Tensorflow.Contexts SafeContextHandle _handle; public SafeContextHandle Handle => _handle; + int? _seed; + Random _rng; + public Context() { _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; @@ -71,6 +74,24 @@ namespace Tensorflow.Contexts initialized = true; } + public void set_global_seed(int? seed) + { + _seed = seed; + if (seed.HasValue) + _rng = new Random(seed.Value); + else + _rng = null; + // Also clear the kernel cache, to reset any existing seeds + if (_handle != null) + c_api.TFE_ContextClearCaches(_handle); + } + + public int? global_seed() + => _seed; + + public int? internal_operation_seed() + => _rng?.Next(0, int.MaxValue); + public void start_step() => c_api.TFE_ContextStartStep(_handle); @@ -86,7 +107,7 @@ namespace Tensorflow.Contexts { if(context_switches.Count() == 0) tf.enable_eager_execution(); - + return context_switches.Current().EagerMode; } diff --git a/src/TensorFlowNET.Core/Framework/random_seed.py.cs b/src/TensorFlowNET.Core/Framework/random_seed.py.cs index e8af1993..8732c030 100644 --- a/src/TensorFlowNET.Core/Framework/random_seed.py.cs +++ b/src/TensorFlowNET.Core/Framework/random_seed.py.cs @@ -14,16 +14,43 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; +using static Tensorflow.Binding; + namespace Tensorflow { public class random_seed { private static int DEFAULT_GRAPH_SEED = 87654321; + private static Dictionary _graph_to_seed_dict = new Dictionary(); public static (int?, int?) get_seed(int? op_seed = null) { + int? global_seed; + + if (tf.executing_eagerly()) + global_seed = tf.Context.global_seed(); + else + global_seed = ops.get_default_graph().seed; + + if (global_seed.HasValue) + { + if (!op_seed.HasValue) + if (tf.executing_eagerly()) + op_seed = tf.Context.internal_operation_seed(); + else + { + if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) + seed = 0; + _graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; + op_seed = seed; + } + + return (global_seed, op_seed); + } + if (op_seed.HasValue) - return (DEFAULT_GRAPH_SEED, 0); + return (DEFAULT_GRAPH_SEED, op_seed); else return (null, null); } diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index 19d774d6..12d41bf2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -79,7 +79,8 @@ namespace Tensorflow /// public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) - => tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value, seed, seed2)); + => tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value) + .SetAttributes(new { seed = seed, seed2 = seed2 })); /// /// Outputs random values from a truncated normal distribution. diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs index ed70fa35..b658586a 100644 --- a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics /// Test the function of setting random seed /// This will help regenerate the same result /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomSeedTest() { var initValue = np.arange(6).reshape(3, 2); @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics /// /// This part we use funcs in tf.random rather than only tf /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomRaodomSeedTest() { tf.set_random_seed(1234); diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 0a858ef9..e8e87840 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -151,5 +151,25 @@ namespace TensorFlowNET.UnitTest.Dataset var cardinality = dataset.dataset_cardinality(); Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); } + + [TestMethod] + public void Shuffle() + { + tf.set_random_seed(1234); + + var dataset = tf.data.Dataset.range(3); + var shuffled = dataset.shuffle(3); + + var zipped = tf.data.Dataset.zip(dataset, shuffled); + + bool allEqual = true; + foreach (var item in zipped) + { + if (item.Item1 != item.Item2) + allEqual = false; + } + + Assert.IsFalse(allEqual); + } } }