@@ -43,6 +43,7 @@ namespace Tensorflow.Contexts | |||||
public SafeContextHandle Handle => _handle; | public SafeContextHandle Handle => _handle; | ||||
int? _seed; | int? _seed; | ||||
Random _rng; | |||||
public Context() | public Context() | ||||
{ | { | ||||
@@ -74,11 +75,23 @@ namespace Tensorflow.Contexts | |||||
} | } | ||||
public void set_global_seed(int? seed) | public void set_global_seed(int? seed) | ||||
=> _seed = seed; | |||||
{ | |||||
_seed = seed; | |||||
if (seed.HasValue) | |||||
_rng = new Random(seed.Value); | |||||
else | |||||
_rng = null; | |||||
// Also clear the kernel cache, to reset any existing seeds | |||||
if (_handle != null) | |||||
c_api.TFE_ContextClearCaches(_handle); | |||||
} | |||||
public int? global_seed() | public int? global_seed() | ||||
=> _seed; | => _seed; | ||||
public int? internal_operation_seed() | |||||
=> _rng?.Next(0, int.MaxValue); | |||||
public void start_step() | public void start_step() | ||||
=> c_api.TFE_ContextStartStep(_handle); | => c_api.TFE_ContextStartStep(_handle); | ||||
@@ -94,7 +107,7 @@ namespace Tensorflow.Contexts | |||||
{ | { | ||||
if(context_switches.Count() == 0) | if(context_switches.Count() == 0) | ||||
tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
return context_switches.Current().EagerMode; | return context_switches.Current().EagerMode; | ||||
} | } | ||||
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -21,6 +22,7 @@ namespace Tensorflow | |||||
public class random_seed | public class random_seed | ||||
{ | { | ||||
private static int DEFAULT_GRAPH_SEED = 87654321; | private static int DEFAULT_GRAPH_SEED = 87654321; | ||||
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>(); | |||||
public static (int?, int?) get_seed(int? op_seed = null) | public static (int?, int?) get_seed(int? op_seed = null) | ||||
{ | { | ||||
@@ -32,7 +34,20 @@ namespace Tensorflow | |||||
global_seed = ops.get_default_graph().seed; | global_seed = ops.get_default_graph().seed; | ||||
if (global_seed.HasValue) | if (global_seed.HasValue) | ||||
{ | |||||
if (!op_seed.HasValue) | |||||
if (tf.executing_eagerly()) | |||||
op_seed = tf.Context.internal_operation_seed(); | |||||
else | |||||
{ | |||||
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) | |||||
seed = 0; | |||||
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; | |||||
op_seed = seed; | |||||
} | |||||
return (global_seed, op_seed); | return (global_seed, op_seed); | ||||
} | |||||
if (op_seed.HasValue) | if (op_seed.HasValue) | ||||
return (DEFAULT_GRAPH_SEED, op_seed); | return (DEFAULT_GRAPH_SEED, op_seed); | ||||
@@ -131,7 +131,9 @@ namespace Tensorflow | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
"RandomShuffle", name, | "RandomShuffle", name, | ||||
null, | null, | ||||
value, seed, seed2); | |||||
value, | |||||
"seed", seed, | |||||
"seed2", seed2); | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
/// Test the function of setting random seed | /// Test the function of setting random seed | ||||
/// This will help regenerate the same result | /// This will help regenerate the same result | ||||
/// </summary> | /// </summary> | ||||
[TestMethod, Ignore] | |||||
[TestMethod] | |||||
public void TFRandomSeedTest() | public void TFRandomSeedTest() | ||||
{ | { | ||||
var initValue = np.arange(6).reshape(3, 2); | var initValue = np.arange(6).reshape(3, 2); | ||||
@@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
/// <summary> | /// <summary> | ||||
/// compare to Test above, seed is also added in params | /// compare to Test above, seed is also added in params | ||||
/// </summary> | /// </summary> | ||||
[TestMethod, Ignore] | |||||
[TestMethod] | |||||
public void TFRandomSeedTest2() | public void TFRandomSeedTest2() | ||||
{ | { | ||||
var initValue = np.arange(6).reshape(3, 2); | var initValue = np.arange(6).reshape(3, 2); | ||||
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
/// <summary> | /// <summary> | ||||
/// This part we use funcs in tf.random rather than only tf | /// This part we use funcs in tf.random rather than only tf | ||||
/// </summary> | /// </summary> | ||||
[TestMethod, Ignore] | |||||
[TestMethod] | |||||
public void TFRandomRaodomSeedTest() | public void TFRandomRaodomSeedTest() | ||||
{ | { | ||||
tf.set_random_seed(1234); | tf.set_random_seed(1234); | ||||
@@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
/// <summary> | /// <summary> | ||||
/// compare to Test above, seed is also added in params | /// compare to Test above, seed is also added in params | ||||
/// </summary> | /// </summary> | ||||
[TestMethod, Ignore] | |||||
[TestMethod] | |||||
public void TFRandomRaodomSeedTest2() | public void TFRandomRaodomSeedTest2() | ||||
{ | { | ||||
tf.set_random_seed(1234); | tf.set_random_seed(1234); | ||||