Browse Source

ContextSwitchStack

tags/v0.20
Oceania2018 5 years ago
parent
commit
32a0f582bf
11 changed files with 57 additions and 31 deletions
  1. +10
    -7
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Contexts/ContextSwitch.cs
  3. +19
    -5
      src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
  4. +0
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  6. +6
    -2
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  7. +12
    -9
      src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
  8. +1
    -1
      src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
  9. +2
    -1
      test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
  10. +2
    -2
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
  11. +1
    -1
      test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj

+ 10
- 7
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -31,8 +31,7 @@ namespace Tensorflow.Contexts
public string DeviceName { get; set; } = "";
public string ScopeName { get; set; } = "";
bool initialized = false;
bool isEager;
ContextSwitchStack contextSwitches;
ContextSwitchStack context_switches;

public SafeContextHandle Handle { get; }

@@ -40,8 +39,7 @@ namespace Tensorflow.Contexts
{
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
status.Check(true);
isEager = defaultExecutionMode == EAGER_MODE;
contextSwitches = new ContextSwitchStack(isEager);
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE);
initialized = true;
}

@@ -66,7 +64,7 @@ namespace Tensorflow.Contexts
/// </summary>
/// <returns></returns>
public bool executing_eagerly()
=> isEager;
=> context_switches.Current().EagerMode;

public string shared_name(string name = null)
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ?
@@ -79,9 +77,14 @@ namespace Tensorflow.Contexts
public void eager_mode()
=> mode(true);

void mode(bool mode)
void mode(bool isEager)
{
isEager = mode;
context_switches.Push(isEager);
}

public void restore_mode()
{
context_switches.Pop();
}

public void Dispose()


+ 2
- 0
src/TensorFlowNET.Core/Contexts/ContextSwitch.cs View File

@@ -22,6 +22,8 @@ namespace Tensorflow.Contexts
{
public class ContextSwitch
{
public bool EagerMode { get; set; }

/// <summary>
/// Whether the context is building a function.
/// </summary>


+ 19
- 5
src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs View File

@@ -30,11 +30,25 @@ namespace Tensorflow.Contexts
public ContextSwitchStack(bool isEager)
{
stack = new Stack<ContextSwitch>();
if (isEager)
stack.Push(new ContextSwitch
{
IsBuildingFunction = false
});
Push(isEager);
}

public void Push(bool isEager)
{
stack.Push(new ContextSwitch
{
EagerMode = isEager
});
}

public void Pop()
{
stack.Pop();
}

public ContextSwitch Current()
{
return stack.Peek();
}
}
}

+ 0
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -148,7 +148,6 @@ namespace Tensorflow
/// <returns></returns>
public Graph as_default()
{
tf.Context.graph_mode();
return ops.set_default_graph(this);
}

@@ -492,7 +491,6 @@ namespace Tensorflow

protected override void DisposeManagedResources()
{
tf.Context.eager_mode();
ops.default_graph_stack.remove(this);
}



+ 2
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -159,7 +159,8 @@ namespace Tensorflow.Keras.Engine
_set_mask_metadata(inputs, outputs, null);
});

tf.Context.eager_mode();
if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs;
}


+ 6
- 2
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -34,10 +34,14 @@ namespace Tensorflow.Keras.Layers
IInitializer embeddings_initializer;

public Embedding(EmbeddingArgs args)
: base(args)
: base(new LayerArgs // copy args
{
DType = args.DType,
Name = args.Name
})
{
this.args = args;
if(args.InputShape == null)
if (args.InputShape == null)
args.InputShape = args.InputLength;

embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer;


+ 12
- 9
src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Layers
{
this.args = args;
built = true;
this.SupportsMasking = true;
SupportsMasking = true;

if(BatchInputShape != null)
{
@@ -58,6 +58,9 @@ namespace Tensorflow.Keras.Layers
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype;
}

// In graph mode, create a graph placeholder to call the layer on.
tf.Context.graph_mode();

if (args.InputTensor == null)
{
if(args.InputShape != null)
@@ -71,15 +74,13 @@ namespace Tensorflow.Keras.Layers
args.BatchInputShape = null;
}

// In graph mode, create a graph placeholder to call the layer on.
tf.Context.graph_mode();
args.InputTensor = tf.keras.backend.placeholder(
shape: BatchInputShape,
dtype: DType,
name: Name,
sparse: args.Sparse,
ragged: args.Ragged);
tf.Context.eager_mode();
shape: BatchInputShape,
dtype: DType,
name: Name,
sparse: args.Sparse,
ragged: args.Ragged);

isPlaceholder = true;
}
@@ -97,6 +98,8 @@ namespace Tensorflow.Keras.Layers
typeSpec = new TensorSpec(args.InputTensor.TensorShape,
dtype: args.InputTensor.dtype,
name: Name);

tf.Context.restore_mode();
}
}
}

+ 1
- 1
src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj View File

@@ -29,7 +29,7 @@

<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.2.0" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
</ItemGroup>

<ItemGroup>


+ 2
- 1
test/TensorFlowNET.UnitTest/EagerModeTestBase.cs View File

@@ -12,7 +12,8 @@ namespace Tensorflow.UnitTest
[TestInitialize]
public void TestInit()
{
tf.enable_eager_execution();
if (!tf.executing_eagerly())
tf.enable_eager_execution();
}

[TestCleanup]


+ 2
- 2
test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Keras
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
/// </summary>
[TestClass]
public class LayersTest : GraphModeTestBase
public class LayersTest : EagerModeTestBase
{
[TestMethod]
public void Sequential()
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.Keras
/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary>
[TestMethod, Ignore]
[TestMethod]
public void Embedding()
{
var model = new Sequential();


+ 1
- 1
test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj View File

@@ -43,7 +43,7 @@

<ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" />
<PackageReference Include="NumSharp.Lite" Version="0.1.7" />


Loading…
Cancel
Save