@@ -42,7 +42,7 @@ namespace Tensorflow.Contexts | |||||
{ | { | ||||
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | ||||
status.Check(true); | status.Check(true); | ||||
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE); | |||||
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | |||||
initialized = true; | initialized = true; | ||||
} | } | ||||
@@ -70,21 +70,19 @@ namespace Tensorflow.Contexts | |||||
public bool executing_eagerly() | public bool executing_eagerly() | ||||
=> context_switches.Current().EagerMode; | => context_switches.Current().EagerMode; | ||||
public bool is_build_function() | |||||
=> context_switches.Current().IsBuildingFunction; | |||||
public string shared_name(string name = null) | public string shared_name(string name = null) | ||||
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ? | => !string.IsNullOrEmpty(name) || !executing_eagerly() ? | ||||
name : | name : | ||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | ||||
public void graph_mode() | |||||
=> mode(false); | |||||
public void eager_mode() | |||||
=> mode(true); | |||||
public void graph_mode(bool isFunc = false) | |||||
=> context_switches.Push(false, isFunc); | |||||
void mode(bool isEager) | |||||
{ | |||||
context_switches.Push(isEager); | |||||
} | |||||
public void eager_mode(bool isFunc = false) | |||||
=> context_switches.Push(true, isFunc); | |||||
public void restore_mode() | public void restore_mode() | ||||
{ | { | ||||
@@ -25,17 +25,18 @@ namespace Tensorflow.Contexts | |||||
{ | { | ||||
Stack<ContextSwitch> stack; | Stack<ContextSwitch> stack; | ||||
public ContextSwitchStack(bool isEager) | |||||
public ContextSwitchStack(bool isEager, bool isFunc) | |||||
{ | { | ||||
stack = new Stack<ContextSwitch>(); | stack = new Stack<ContextSwitch>(); | ||||
Push(isEager); | |||||
Push(isEager, isFunc); | |||||
} | } | ||||
public void Push(bool isEager) | |||||
public void Push(bool isEager, bool isFunc) | |||||
{ | { | ||||
stack.Push(new ContextSwitch | stack.Push(new ContextSwitch | ||||
{ | { | ||||
EagerMode = isEager | |||||
EagerMode = isEager, | |||||
IsBuildingFunction = isFunc | |||||
}); | }); | ||||
} | } | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine | |||||
// using var graph = tf.keras.backend.get_graph().as_default(); | // using var graph = tf.keras.backend.get_graph().as_default(); | ||||
if (!inputs.IsEagerTensor) | if (!inputs.IsEagerTensor) | ||||
tf.Context.graph_mode(); | |||||
tf.Context.graph_mode(isFunc: true); | |||||
tf_with(ops.name_scope(_name_scope()), scope => | tf_with(ops.name_scope(_name_scope()), scope => | ||||
{ | { | ||||
@@ -176,9 +176,17 @@ namespace Tensorflow.Keras.Engine | |||||
tf.init_scope(); | tf.init_scope(); | ||||
tf.Context.eager_mode(); | |||||
bool need_restore_mode = false; | |||||
if (inputs.IsEagerTensor || tf.Context.is_build_function()) | |||||
{ | |||||
need_restore_mode = true; | |||||
tf.Context.eager_mode(); | |||||
} | |||||
build(inputs); | build(inputs); | ||||
tf.Context.restore_mode(); | |||||
if (need_restore_mode) | |||||
tf.Context.restore_mode(); | |||||
built = true; | built = true; | ||||
} | } | ||||