Browse Source

fix is_build_function in Layer.

tags/v0.30
Oceania2018 4 years ago
parent
commit
2ea8f2eb3f
4 changed files with 24 additions and 17 deletions
  1. +8
    -10
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +5
    -4
      src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
  4. +10
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 8
- 10
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Contexts
{
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
status.Check(true);
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE);
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
initialized = true;
}

@@ -70,21 +70,19 @@ namespace Tensorflow.Contexts
public bool executing_eagerly()
=> context_switches.Current().EagerMode;

public bool is_build_function()
=> context_switches.Current().IsBuildingFunction;

public string shared_name(string name = null)
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ?
name :
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";

public void graph_mode()
=> mode(false);

public void eager_mode()
=> mode(true);
public void graph_mode(bool isFunc = false)
=> context_switches.Push(false, isFunc);

void mode(bool isEager)
{
context_switches.Push(isEager);
}
public void eager_mode(bool isFunc = false)
=> context_switches.Push(true, isFunc);

public void restore_mode()
{


+ 5
- 4
src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs View File

@@ -25,17 +25,18 @@ namespace Tensorflow.Contexts
{
Stack<ContextSwitch> stack;

public ContextSwitchStack(bool isEager)
public ContextSwitchStack(bool isEager, bool isFunc)
{
stack = new Stack<ContextSwitch>();
Push(isEager);
Push(isEager, isFunc);
}

public void Push(bool isEager)
public void Push(bool isEager, bool isFunc)
{
stack.Push(new ContextSwitch
{
EagerMode = isEager
EagerMode = isEager,
IsBuildingFunction = isFunc
});
}



+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
// using var graph = tf.keras.backend.get_graph().as_default();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode();
tf.Context.graph_mode(isFunc: true);

tf_with(ops.name_scope(_name_scope()), scope =>
{


+ 10
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -176,9 +176,17 @@ namespace Tensorflow.Keras.Engine

tf.init_scope();

tf.Context.eager_mode();
bool need_restore_mode = false;
if (inputs.IsEagerTensor || tf.Context.is_build_function())
{
need_restore_mode = true;
tf.Context.eager_mode();
}
build(inputs);
tf.Context.restore_mode();

if (need_restore_mode)
tf.Context.restore_mode();

built = true;
}


Loading…
Cancel
Save