diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index bd74c8fd..d5656e87 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -93,7 +93,12 @@ namespace Tensorflow => random_ops.random_shuffle(value, seed: seed, name: name); public void set_random_seed(int seed) - => ops.get_default_graph().seed = seed; + { + if (executing_eagerly()) + Context.set_global_seed(seed); + else + ops.get_default_graph().seed = seed; + } public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index be4b56b2..04b994ed 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -42,6 +42,8 @@ namespace Tensorflow.Contexts SafeContextHandle _handle; public SafeContextHandle Handle => _handle; + int? _seed; + public Context() { _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; @@ -71,6 +73,12 @@ namespace Tensorflow.Contexts initialized = true; } + public void set_global_seed(int? seed) + => _seed = seed; + + public int? global_seed() + => _seed; + public void start_step() => c_api.TFE_ContextStartStep(_handle); diff --git a/src/TensorFlowNET.Core/Framework/random_seed.py.cs b/src/TensorFlowNET.Core/Framework/random_seed.py.cs index e8af1993..e51b7b81 100644 --- a/src/TensorFlowNET.Core/Framework/random_seed.py.cs +++ b/src/TensorFlowNET.Core/Framework/random_seed.py.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public class random_seed @@ -22,8 +24,18 @@ namespace Tensorflow public static (int?, int?) get_seed(int? op_seed = null) { + int? global_seed; + + if (tf.executing_eagerly()) + global_seed = tf.Context.global_seed(); + else + global_seed = ops.get_default_graph().seed; + + if (global_seed.HasValue) + return (global_seed, op_seed); + if (op_seed.HasValue) - return (DEFAULT_GRAPH_SEED, 0); + return (DEFAULT_GRAPH_SEED, op_seed); else return (null, null); }