Browse Source

Add global_seed to context

tags/v0.40-tf2.4-tstring
Sebastian Hantsch 4 years ago
parent
commit
e5b40b0d0a
3 changed files with 27 additions and 2 deletions
  1. +6
    -1
      src/TensorFlowNET.Core/APIs/tf.random.cs
  2. +8
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  3. +13
    -1
      src/TensorFlowNET.Core/Framework/random_seed.py.cs

+ 6
- 1
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -93,7 +93,12 @@ namespace Tensorflow
=> random_ops.random_shuffle(value, seed: seed, name: name);

public void set_random_seed(int seed)
=> ops.get_default_graph().seed = seed;
{
if (executing_eagerly())
Context.set_global_seed(seed);
else
ops.get_default_graph().seed = seed;
}

public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)


+ 8
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -42,6 +42,8 @@ namespace Tensorflow.Contexts
SafeContextHandle _handle;
public SafeContextHandle Handle => _handle;

int? _seed;

public Context()
{
_device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT;
@@ -71,6 +73,12 @@ namespace Tensorflow.Contexts
initialized = true;
}

public void set_global_seed(int? seed)
=> _seed = seed;

public int? global_seed()
=> _seed;

public void start_step()
=> c_api.TFE_ContextStartStep(_handle);



+ 13
- 1
src/TensorFlowNET.Core/Framework/random_seed.py.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using static Tensorflow.Binding;

namespace Tensorflow
{
public class random_seed
@@ -22,8 +24,18 @@ namespace Tensorflow

public static (int?, int?) get_seed(int? op_seed = null)
{
int? global_seed;

if (tf.executing_eagerly())
global_seed = tf.Context.global_seed();
else
global_seed = ops.get_default_graph().seed;

if (global_seed.HasValue)
return (global_seed, op_seed);

if (op_seed.HasValue)
return (DEFAULT_GRAPH_SEED, 0);
return (DEFAULT_GRAPH_SEED, op_seed);
else
return (null, null);
}


Loading…
Cancel
Save