Browse Source

Pass named args to RandomShuffle; seed handling more simular to python

tags/v0.40-tf2.4-tstring
Sebastian Hantsch 4 years ago
parent
commit
d3e85fe84e
4 changed files with 37 additions and 7 deletions
  1. +15
    -2
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +15
    -0
      src/TensorFlowNET.Core/Framework/random_seed.py.cs
  3. +3
    -1
      src/TensorFlowNET.Core/Operations/gen_random_ops.cs
  4. +4
    -4
      test/TensorFlowNET.UnitTest/Basics/RandomTest.cs

+ 15
- 2
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -43,6 +43,7 @@ namespace Tensorflow.Contexts
public SafeContextHandle Handle => _handle;

int? _seed;
Random _rng;

public Context()
{
@@ -74,11 +75,23 @@ namespace Tensorflow.Contexts
}

public void set_global_seed(int? seed)
=> _seed = seed;
{
_seed = seed;
if (seed.HasValue)
_rng = new Random(seed.Value);
else
_rng = null;
// Also clear the kernel cache, to reset any existing seeds
if (_handle != null)
c_api.TFE_ContextClearCaches(_handle);
}

public int? global_seed()
=> _seed;

public int? internal_operation_seed()
=> _rng?.Next(0, int.MaxValue);

public void start_step()
=> c_api.TFE_ContextStartStep(_handle);

@@ -94,7 +107,7 @@ namespace Tensorflow.Contexts
{
if(context_switches.Count() == 0)
tf.enable_eager_execution();
return context_switches.Current().EagerMode;
}



+ 15
- 0
src/TensorFlowNET.Core/Framework/random_seed.py.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -21,6 +22,7 @@ namespace Tensorflow
public class random_seed
{
private static int DEFAULT_GRAPH_SEED = 87654321;
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>();

public static (int?, int?) get_seed(int? op_seed = null)
{
@@ -32,7 +34,20 @@ namespace Tensorflow
global_seed = ops.get_default_graph().seed;

if (global_seed.HasValue)
{
if (!op_seed.HasValue)
if (tf.executing_eagerly())
op_seed = tf.Context.internal_operation_seed();
else
{
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed))
seed = 0;
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1;
op_seed = seed;
}

return (global_seed, op_seed);
}

if (op_seed.HasValue)
return (DEFAULT_GRAPH_SEED, op_seed);


+ 3
- 1
src/TensorFlowNET.Core/Operations/gen_random_ops.cs View File

@@ -131,7 +131,9 @@ namespace Tensorflow
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomShuffle", name,
null,
value, seed, seed2);
value,
"seed", seed,
"seed2", seed2);

return results[0];
}


+ 4
- 4
test/TensorFlowNET.UnitTest/Basics/RandomTest.cs View File

@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// Test the function of setting random seed
/// This will help regenerate the same result
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomSeedTest()
{
var initValue = np.arange(6).reshape(3, 2);
@@ -37,7 +37,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// <summary>
/// compare to Test above, seed is also added in params
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomSeedTest2()
{
var initValue = np.arange(6).reshape(3, 2);
@@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// <summary>
/// This part we use funcs in tf.random rather than only tf
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomRaodomSeedTest()
{
tf.set_random_seed(1234);
@@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest.Basics
/// <summary>
/// compare to Test above, seed is also added in params
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void TFRandomRaodomSeedTest2()
{
tf.set_random_seed(1234);


Loading…
Cancel
Save