|
|
@@ -14,16 +14,43 @@ |
|
|
|
limitations under the License. |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using System.Collections.Generic; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
|
{ |
|
|
|
public class random_seed |
|
|
|
{ |
|
|
|
private static int DEFAULT_GRAPH_SEED = 87654321; |
|
|
|
private static Dictionary<string, int> _graph_to_seed_dict = new Dictionary<string, int>(); |
|
|
|
|
|
|
|
public static (int?, int?) get_seed(int? op_seed = null) |
|
|
|
{ |
|
|
|
int? global_seed; |
|
|
|
|
|
|
|
if (tf.executing_eagerly()) |
|
|
|
global_seed = tf.Context.global_seed(); |
|
|
|
else |
|
|
|
global_seed = ops.get_default_graph().seed; |
|
|
|
|
|
|
|
if (global_seed.HasValue) |
|
|
|
{ |
|
|
|
if (!op_seed.HasValue) |
|
|
|
if (tf.executing_eagerly()) |
|
|
|
op_seed = tf.Context.internal_operation_seed(); |
|
|
|
else |
|
|
|
{ |
|
|
|
if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) |
|
|
|
seed = 0; |
|
|
|
_graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; |
|
|
|
op_seed = seed; |
|
|
|
} |
|
|
|
|
|
|
|
return (global_seed, op_seed); |
|
|
|
} |
|
|
|
|
|
|
|
if (op_seed.HasValue) |
|
|
|
return (DEFAULT_GRAPH_SEED, 0); |
|
|
|
return (DEFAULT_GRAPH_SEED, op_seed); |
|
|
|
else |
|
|
|
return (null, null); |
|
|
|
} |
|
|
|