Random seed handling more simular to pythontags/v0.40-tf2.4-tstring
@@ -93,7 +93,12 @@ namespace Tensorflow | |||||
=> random_ops.random_shuffle(value, seed: seed, name: name); | => random_ops.random_shuffle(value, seed: seed, name: name); | ||||
public void set_random_seed(int seed) | public void set_random_seed(int seed) | ||||
=> ops.get_default_graph().seed = seed; | |||||
{ | |||||
if (executing_eagerly()) | |||||
Context.set_global_seed(seed); | |||||
else | |||||
ops.get_default_graph().seed = seed; | |||||
} | |||||
public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | ||||
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | ||||
@@ -42,6 +42,9 @@ namespace Tensorflow.Contexts | |||||
SafeContextHandle _handle; | SafeContextHandle _handle; | ||||
public SafeContextHandle Handle => _handle; | public SafeContextHandle Handle => _handle; | ||||
int? _seed; | |||||
Random _rng; | |||||
public Context() | public Context() | ||||
{ | { | ||||
_device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; | _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; | ||||
@@ -71,6 +74,24 @@ namespace Tensorflow.Contexts | |||||
initialized = true; | initialized = true; | ||||
} | } | ||||
public void set_global_seed(int? seed) | |||||
{ | |||||
_seed = seed; | |||||
if (seed.HasValue) | |||||
_rng = new Random(seed.Value); | |||||
else | |||||
_rng = null; | |||||
// Also clear the kernel cache, to reset any existing seeds | |||||
if (_handle != null) | |||||
c_api.TFE_ContextClearCaches(_handle); | |||||
} | |||||
public int? global_seed() | |||||
=> _seed; | |||||
public int? internal_operation_seed() | |||||
=> _rng?.Next(0, int.MaxValue); | |||||
public void start_step() | public void start_step() | ||||
=> c_api.TFE_ContextStartStep(_handle); | => c_api.TFE_ContextStartStep(_handle); | ||||
@@ -86,7 +107,7 @@ namespace Tensorflow.Contexts | |||||
{ | { | ||||
if(context_switches.Count() == 0) | if(context_switches.Count() == 0) | ||||
tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
return context_switches.Current().EagerMode; | return context_switches.Current().EagerMode; | ||||
} | } | ||||
@@ -14,16 +14,43 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class random_seed | public class random_seed | ||||
{ | { | ||||
private static int DEFAULT_GRAPH_SEED = 87654321; | private static int DEFAULT_GRAPH_SEED = 87654321; | ||||
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>(); | |||||
public static (int?, int?) get_seed(int? op_seed = null) | public static (int?, int?) get_seed(int? op_seed = null) | ||||
{ | { | ||||
int? global_seed; | |||||
if (tf.executing_eagerly()) | |||||
global_seed = tf.Context.global_seed(); | |||||
else | |||||
global_seed = ops.get_default_graph().seed; | |||||
if (global_seed.HasValue) | |||||
{ | |||||
if (!op_seed.HasValue) | |||||
if (tf.executing_eagerly()) | |||||
op_seed = tf.Context.internal_operation_seed(); | |||||
else | |||||
{ | |||||
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) | |||||
seed = 0; | |||||
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; | |||||
op_seed = seed; | |||||
} | |||||
return (global_seed, op_seed); | |||||
} | |||||
if (op_seed.HasValue) | if (op_seed.HasValue) | ||||
return (DEFAULT_GRAPH_SEED, 0); | |||||
return (DEFAULT_GRAPH_SEED, op_seed); | |||||
else | else | ||||
return (null, null); | return (null, null); | ||||
} | } | ||||
@@ -79,7 +79,8 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | ||||
string name = null) | string name = null) | ||||
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value, seed, seed2)); | |||||
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value) | |||||
.SetAttributes(new { seed = seed, seed2 = seed2 })); | |||||
/// <summary> | /// <summary> | ||||
/// Outputs random values from a truncated normal distribution. | /// Outputs random values from a truncated normal distribution. | ||||
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
/// Test the function of setting random seed | /// Test the function of setting random seed | ||||
/// This will help regenerate the same result | /// This will help regenerate the same result | ||||
/// </summary> | /// </summary> | ||||
[TestMethod, Ignore] | |||||
[TestMethod] | |||||
public void TFRandomSeedTest() | public void TFRandomSeedTest() | ||||
{ | { | ||||
var initValue = np.arange(6).reshape(3, 2); | var initValue = np.arange(6).reshape(3, 2); | ||||
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
/// <summary> | /// <summary> | ||||
/// This part we use funcs in tf.random rather than only tf | /// This part we use funcs in tf.random rather than only tf | ||||
/// </summary> | /// </summary> | ||||
[TestMethod, Ignore] | |||||
[TestMethod] | |||||
public void TFRandomRaodomSeedTest() | public void TFRandomRaodomSeedTest() | ||||
{ | { | ||||
tf.set_random_seed(1234); | tf.set_random_seed(1234); | ||||
@@ -151,5 +151,25 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
var cardinality = dataset.dataset_cardinality(); | var cardinality = dataset.dataset_cardinality(); | ||||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | ||||
} | } | ||||
[TestMethod] | |||||
public void Shuffle() | |||||
{ | |||||
tf.set_random_seed(1234); | |||||
var dataset = tf.data.Dataset.range(3); | |||||
var shuffled = dataset.shuffle(3); | |||||
var zipped = tf.data.Dataset.zip(dataset, shuffled); | |||||
bool allEqual = true; | |||||
foreach (var item in zipped) | |||||
{ | |||||
if (item.Item1 != item.Item2) | |||||
allEqual = false; | |||||
} | |||||
Assert.IsFalse(allEqual); | |||||
} | |||||
} | } | ||||
} | } |