Browse Source

Merge pull request #762 from Sebastian-roko/random_seed

Random seed handling more simular to python
tags/v0.40-tf2.4-tstring
Esther Hu GitHub 4 years ago
parent
commit
a22c52bba7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 80 additions and 6 deletions
  1. +6
    -1
      src/TensorFlowNET.Core/APIs/tf.random.cs
  2. +22
    -1
      src/TensorFlowNET.Core/Contexts/Context.cs
  3. +28
    -1
      src/TensorFlowNET.Core/Framework/random_seed.py.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Operations/gen_random_ops.cs
  5. +2
    -2
      test/TensorFlowNET.UnitTest/Basics/RandomTest.cs
  6. +20
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 6
- 1
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -93,7 +93,12 @@ namespace Tensorflow
=> random_ops.random_shuffle(value, seed: seed, name: name); => random_ops.random_shuffle(value, seed: seed, name: name);


public void set_random_seed(int seed) public void set_random_seed(int seed)
=> ops.get_default_graph().seed = seed;
{
if (executing_eagerly())
Context.set_global_seed(seed);
else
ops.get_default_graph().seed = seed;
}


public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)


+ 22
- 1
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -42,6 +42,9 @@ namespace Tensorflow.Contexts
SafeContextHandle _handle; SafeContextHandle _handle;
public SafeContextHandle Handle => _handle; public SafeContextHandle Handle => _handle;


int? _seed;
Random _rng;

public Context() public Context()
{ {
_device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT;
@@ -71,6 +74,24 @@ namespace Tensorflow.Contexts
initialized = true; initialized = true;
} }


public void set_global_seed(int? seed)
{
_seed = seed;
if (seed.HasValue)
_rng = new Random(seed.Value);
else
_rng = null;
// Also clear the kernel cache, to reset any existing seeds
if (_handle != null)
c_api.TFE_ContextClearCaches(_handle);
}

public int? global_seed()
=> _seed;

public int? internal_operation_seed()
=> _rng?.Next(0, int.MaxValue);

public void start_step() public void start_step()
=> c_api.TFE_ContextStartStep(_handle); => c_api.TFE_ContextStartStep(_handle);


@@ -86,7 +107,7 @@ namespace Tensorflow.Contexts
{ {
if(context_switches.Count() == 0) if(context_switches.Count() == 0)
tf.enable_eager_execution(); tf.enable_eager_execution();
return context_switches.Current().EagerMode; return context_switches.Current().EagerMode;
} }




+ 28
- 1
src/TensorFlowNET.Core/Framework/random_seed.py.cs View File

@@ -14,16 +14,43 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow namespace Tensorflow
{ {
public class random_seed public class random_seed
{ {
private static int DEFAULT_GRAPH_SEED = 87654321; private static int DEFAULT_GRAPH_SEED = 87654321;
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>();


public static (int?, int?) get_seed(int? op_seed = null) public static (int?, int?) get_seed(int? op_seed = null)
{ {
int? global_seed;

if (tf.executing_eagerly())
global_seed = tf.Context.global_seed();
else
global_seed = ops.get_default_graph().seed;

if (global_seed.HasValue)
{
if (!op_seed.HasValue)
if (tf.executing_eagerly())
op_seed = tf.Context.internal_operation_seed();
else
{
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed))
seed = 0;
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1;
op_seed = seed;
}

return (global_seed, op_seed);
}

if (op_seed.HasValue) if (op_seed.HasValue)
return (DEFAULT_GRAPH_SEED, 0);
return (DEFAULT_GRAPH_SEED, op_seed);
else else
return (null, null); return (null, null);
} }


+ 2
- 1
src/TensorFlowNET.Core/Operations/gen_random_ops.cs View File

@@ -79,7 +79,8 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
string name = null) string name = null)
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value, seed, seed2));
=> tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value)
.SetAttributes(new { seed = seed, seed2 = seed2 }));


/// <summary> /// <summary>
/// Outputs random values from a truncated normal distribution. /// Outputs random values from a truncated normal distribution.


+ 2
- 2
test/TensorFlowNET.UnitTest/Basics/RandomTest.cs View File

@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// Test the function of setting random seed /// Test the function of setting random seed
/// This will help regenerate the same result /// This will help regenerate the same result
/// </summary> /// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomSeedTest() public void TFRandomSeedTest()
{ {
var initValue = np.arange(6).reshape(3, 2); var initValue = np.arange(6).reshape(3, 2);
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// <summary> /// <summary>
/// This part we use funcs in tf.random rather than only tf /// This part we use funcs in tf.random rather than only tf
/// </summary> /// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomRaodomSeedTest() public void TFRandomRaodomSeedTest()
{ {
tf.set_random_seed(1234); tf.set_random_seed(1234);


+ 20
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -151,5 +151,25 @@ namespace TensorFlowNET.UnitTest.Dataset
var cardinality = dataset.dataset_cardinality(); var cardinality = dataset.dataset_cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
} }

[TestMethod]
public void Shuffle()
{
tf.set_random_seed(1234);

var dataset = tf.data.Dataset.range(3);
var shuffled = dataset.shuffle(3);

var zipped = tf.data.Dataset.zip(dataset, shuffled);

bool allEqual = true;
foreach (var item in zipped)
{
if (item.Item1 != item.Item2)
allEqual = false;
}

Assert.IsFalse(allEqual);
}
} }
} }

Loading…
Cancel
Save