diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 3ba50a62..24e24a01 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Contexts { Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); status.Check(true); - context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE); + context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); initialized = true; } @@ -70,21 +70,19 @@ namespace Tensorflow.Contexts public bool executing_eagerly() => context_switches.Current().EagerMode; + public bool is_build_function() + => context_switches.Current().IsBuildingFunction; + public string shared_name(string name = null) => !string.IsNullOrEmpty(name) || !executing_eagerly() ? name : "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; - public void graph_mode() - => mode(false); - - public void eager_mode() - => mode(true); + public void graph_mode(bool isFunc = false) + => context_switches.Push(false, isFunc); - void mode(bool isEager) - { - context_switches.Push(isEager); - } + public void eager_mode(bool isFunc = false) + => context_switches.Push(true, isFunc); public void restore_mode() { diff --git a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs index e4011b68..84bc3889 100644 --- a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs +++ b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs @@ -25,17 +25,18 @@ namespace Tensorflow.Contexts { Stack stack; - public ContextSwitchStack(bool isEager) + public ContextSwitchStack(bool isEager, bool isFunc) { stack = new Stack(); - Push(isEager); + Push(isEager, isFunc); } - public void Push(bool isEager) + public void Push(bool isEager, bool isFunc) { stack.Push(new ContextSwitch { - EagerMode = isEager + EagerMode = isEager, + IsBuildingFunction = isFunc }); } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs index 391d47de..ba39e49f 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine // using var graph = tf.keras.backend.get_graph().as_default(); if (!inputs.IsEagerTensor) - tf.Context.graph_mode(); + tf.Context.graph_mode(isFunc: true); tf_with(ops.name_scope(_name_scope()), scope => { diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 22fba034..fa072da9 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -176,9 +176,17 @@ namespace Tensorflow.Keras.Engine tf.init_scope(); - tf.Context.eager_mode(); + bool need_restore_mode = false; + if (inputs.IsEagerTensor || tf.Context.is_build_function()) + { + need_restore_mode = true; + tf.Context.eager_mode(); + } + build(inputs); - tf.Context.restore_mode(); + + if (need_restore_mode) + tf.Context.restore_mode(); built = true; }