@@ -43,6 +43,7 @@ namespace Tensorflow.Contexts | |||
public SafeContextHandle Handle => _handle; | |||
int? _seed; | |||
Random _rng; | |||
public Context() | |||
{ | |||
@@ -74,11 +75,23 @@ namespace Tensorflow.Contexts | |||
} | |||
public void set_global_seed(int? seed) | |||
=> _seed = seed; | |||
{ | |||
_seed = seed; | |||
if (seed.HasValue) | |||
_rng = new Random(seed.Value); | |||
else | |||
_rng = null; | |||
// Also clear the kernel cache, to reset any existing seeds | |||
if (_handle != null) | |||
c_api.TFE_ContextClearCaches(_handle); | |||
} | |||
public int? global_seed() | |||
=> _seed; | |||
public int? internal_operation_seed() | |||
=> _rng?.Next(0, int.MaxValue); | |||
public void start_step() | |||
=> c_api.TFE_ContextStartStep(_handle); | |||
@@ -94,7 +107,7 @@ namespace Tensorflow.Contexts | |||
{ | |||
if(context_switches.Count() == 0) | |||
tf.enable_eager_execution(); | |||
return context_switches.Current().EagerMode; | |||
} | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -21,6 +22,7 @@ namespace Tensorflow | |||
public class random_seed | |||
{ | |||
private static int DEFAULT_GRAPH_SEED = 87654321; | |||
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>(); | |||
public static (int?, int?) get_seed(int? op_seed = null) | |||
{ | |||
@@ -32,7 +34,20 @@ namespace Tensorflow | |||
global_seed = ops.get_default_graph().seed; | |||
if (global_seed.HasValue) | |||
{ | |||
if (!op_seed.HasValue) | |||
if (tf.executing_eagerly()) | |||
op_seed = tf.Context.internal_operation_seed(); | |||
else | |||
{ | |||
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) | |||
seed = 0; | |||
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; | |||
op_seed = seed; | |||
} | |||
return (global_seed, op_seed); | |||
} | |||
if (op_seed.HasValue) | |||
return (DEFAULT_GRAPH_SEED, op_seed); | |||
@@ -131,7 +131,9 @@ namespace Tensorflow | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"RandomShuffle", name, | |||
null, | |||
value, seed, seed2); | |||
value, | |||
"seed", seed, | |||
"seed2", seed2); | |||
return results[0]; | |||
} | |||
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
/// Test the function of setting random seed | |||
/// This will help regenerate the same result | |||
/// </summary> | |||
[TestMethod, Ignore] | |||
[TestMethod] | |||
public void TFRandomSeedTest() | |||
{ | |||
var initValue = np.arange(6).reshape(3, 2); | |||
@@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
/// <summary> | |||
/// compare to Test above, seed is also added in params | |||
/// </summary> | |||
[TestMethod, Ignore] | |||
[TestMethod] | |||
public void TFRandomSeedTest2() | |||
{ | |||
var initValue = np.arange(6).reshape(3, 2); | |||
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
/// <summary> | |||
/// This part we use funcs in tf.random rather than only tf | |||
/// </summary> | |||
[TestMethod, Ignore] | |||
[TestMethod] | |||
public void TFRandomRaodomSeedTest() | |||
{ | |||
tf.set_random_seed(1234); | |||
@@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
/// <summary> | |||
/// compare to Test above, seed is also added in params | |||
/// </summary> | |||
[TestMethod, Ignore] | |||
[TestMethod] | |||
public void TFRandomRaodomSeedTest2() | |||
{ | |||
tf.set_random_seed(1234); | |||