Browse Source

Merge pull request #762 from Sebastian-roko/random_seed

Random seed handling more simular to python
tags/v0.40-tf2.4-tstring
Esther Hu GitHub 4 years ago
parent
commit
a22c52bba7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 80 additions and 6 deletions
  1. +6
    -1
      src/TensorFlowNET.Core/APIs/tf.random.cs
  2. +22
    -1
      src/TensorFlowNET.Core/Contexts/Context.cs
  3. +28
    -1
      src/TensorFlowNET.Core/Framework/random_seed.py.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Operations/gen_random_ops.cs
  5. +2
    -2
      test/TensorFlowNET.UnitTest/Basics/RandomTest.cs
  6. +20
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 6
- 1
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -93,7 +93,12 @@ namespace Tensorflow
=> random_ops.random_shuffle(value, seed: seed, name: name);

public void set_random_seed(int seed)
=> ops.get_default_graph().seed = seed;
{
if (executing_eagerly())
Context.set_global_seed(seed);
else
ops.get_default_graph().seed = seed;
}

public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)


+ 22
- 1
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -42,6 +42,9 @@ namespace Tensorflow.Contexts
SafeContextHandle _handle;
public SafeContextHandle Handle => _handle;

int? _seed;
Random _rng;

public Context()
{
_device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT;
@@ -71,6 +74,24 @@ namespace Tensorflow.Contexts
initialized = true;
}

public void set_global_seed(int? seed)
{
_seed = seed;
if (seed.HasValue)
_rng = new Random(seed.Value);
else
_rng = null;
// Also clear the kernel cache, to reset any existing seeds
if (_handle != null)
c_api.TFE_ContextClearCaches(_handle);
}

public int? global_seed()
=> _seed;

public int? internal_operation_seed()
=> _rng?.Next(0, int.MaxValue);

public void start_step()
=> c_api.TFE_ContextStartStep(_handle);

@@ -86,7 +107,7 @@ namespace Tensorflow.Contexts
{
if(context_switches.Count() == 0)
tf.enable_eager_execution();
return context_switches.Current().EagerMode;
}



+ 28
- 1
src/TensorFlowNET.Core/Framework/random_seed.py.cs View File

@@ -14,16 +14,43 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class random_seed
{
private static int DEFAULT_GRAPH_SEED = 87654321;
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>();

public static (int?, int?) get_seed(int? op_seed = null)
{
int? global_seed;

if (tf.executing_eagerly())
global_seed = tf.Context.global_seed();
else
global_seed = ops.get_default_graph().seed;

if (global_seed.HasValue)
{
if (!op_seed.HasValue)
if (tf.executing_eagerly())
op_seed = tf.Context.internal_operation_seed();
else
{
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed))
seed = 0;
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1;
op_seed = seed;
}

return (global_seed, op_seed);
}

if (op_seed.HasValue)
return (DEFAULT_GRAPH_SEED, 0);
return (DEFAULT_GRAPH_SEED, op_seed);
else
return (null, null);
}


+ 2
- 1
src/TensorFlowNET.Core/Operations/gen_random_ops.cs View File

@@ -79,7 +79,8 @@ namespace Tensorflow
/// <returns></returns>
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
string name = null)
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value, seed, seed2));
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value)
.SetAttributes(new { seed = seed, seed2 = seed2 }));

/// <summary>
/// Outputs random values from a truncated normal distribution.


+ 2
- 2
test/TensorFlowNET.UnitTest/Basics/RandomTest.cs View File

@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// Test the function of setting random seed
/// This will help regenerate the same result
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomSeedTest()
{
var initValue = np.arange(6).reshape(3, 2);
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// <summary>
/// This part we use funcs in tf.random rather than only tf
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomRaodomSeedTest()
{
tf.set_random_seed(1234);


+ 20
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -151,5 +151,25 @@ namespace TensorFlowNET.UnitTest.Dataset
var cardinality = dataset.dataset_cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
}

[TestMethod]
public void Shuffle()
{
tf.set_random_seed(1234);

var dataset = tf.data.Dataset.range(3);
var shuffled = dataset.shuffle(3);

var zipped = tf.data.Dataset.zip(dataset, shuffled);

bool allEqual = true;
foreach (var item in zipped)
{
if (item.Item1 != item.Item2)
allEqual = false;
}

Assert.IsFalse(allEqual);
}
}
}

Loading…
Cancel
Save