Browse Source

fix is_build_function in Layer.

tags/v0.30
Oceania2018 4 years ago
parent
commit
2ea8f2eb3f
4 changed files with 24 additions and 17 deletions
  1. +8
    -10
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +5
    -4
      src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
  4. +10
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 8
- 10
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Contexts
{ {
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
status.Check(true); status.Check(true);
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE);
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
initialized = true; initialized = true;
} }


@@ -70,21 +70,19 @@ namespace Tensorflow.Contexts
public bool executing_eagerly() public bool executing_eagerly()
=> context_switches.Current().EagerMode; => context_switches.Current().EagerMode;


public bool is_build_function()
=> context_switches.Current().IsBuildingFunction;

public string shared_name(string name = null) public string shared_name(string name = null)
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ? => !string.IsNullOrEmpty(name) || !executing_eagerly() ?
name : name :
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; "cd2c89b7-88b7-44c8-ad83-06c2a9158347";


public void graph_mode()
=> mode(false);

public void eager_mode()
=> mode(true);
public void graph_mode(bool isFunc = false)
=> context_switches.Push(false, isFunc);


void mode(bool isEager)
{
context_switches.Push(isEager);
}
public void eager_mode(bool isFunc = false)
=> context_switches.Push(true, isFunc);


public void restore_mode() public void restore_mode()
{ {


+ 5
- 4
src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs View File

@@ -25,17 +25,18 @@ namespace Tensorflow.Contexts
{ {
Stack<ContextSwitch> stack; Stack<ContextSwitch> stack;


public ContextSwitchStack(bool isEager)
public ContextSwitchStack(bool isEager, bool isFunc)
{ {
stack = new Stack<ContextSwitch>(); stack = new Stack<ContextSwitch>();
Push(isEager);
Push(isEager, isFunc);
} }


public void Push(bool isEager)
public void Push(bool isEager, bool isFunc)
{ {
stack.Push(new ContextSwitch stack.Push(new ContextSwitch
{ {
EagerMode = isEager
EagerMode = isEager,
IsBuildingFunction = isFunc
}); });
} }




+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
// using var graph = tf.keras.backend.get_graph().as_default(); // using var graph = tf.keras.backend.get_graph().as_default();


if (!inputs.IsEagerTensor) if (!inputs.IsEagerTensor)
tf.Context.graph_mode();
tf.Context.graph_mode(isFunc: true);


tf_with(ops.name_scope(_name_scope()), scope => tf_with(ops.name_scope(_name_scope()), scope =>
{ {


+ 10
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -176,9 +176,17 @@ namespace Tensorflow.Keras.Engine


tf.init_scope(); tf.init_scope();


tf.Context.eager_mode();
bool need_restore_mode = false;
if (inputs.IsEagerTensor || tf.Context.is_build_function())
{
need_restore_mode = true;
tf.Context.eager_mode();
}
build(inputs); build(inputs);
tf.Context.restore_mode();

if (need_restore_mode)
tf.Context.restore_mode();


built = true; built = true;
} }


Loading…
Cancel
Save