From 58d2daebfa30fd8256d7562e9577d41ec64944fc Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 2 Jan 2021 22:19:00 -0600 Subject: [PATCH] Fix graph instance in InputLayer. --- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 13 +++---------- src/TensorFlowNET.Core/Sessions/BaseSession.cs | 3 ++- src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs | 10 ++++------ src/TensorFlowNET.Keras/Utils/base_layer_utils.cs | 15 ++------------- 4 files changed, 11 insertions(+), 30 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 0d0ac581..b2c02dfe 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -17,8 +17,8 @@ namespace Tensorflow.Graphs IntPtr func_handle; public string FuncName => _graph_key; - public Tensors Inputs { get; set; } - public Tensors Outputs { get; set; } + public Tensors Inputs { get; set; } = new Tensors(); + public Tensors Outputs { get; set; } = new Tensors(); public Dictionary Attrs { get; set; } public Dictionary _captures @@ -175,14 +175,7 @@ namespace Tensorflow.Graphs void add_capture(Tensor tensor, Tensor placeholder) { _captures.Add(tensor.Id, (tensor, placeholder)); - if (Inputs == null) - Inputs = new Tensors(placeholder); - else - { - var inputs = Inputs.ToList(); - inputs.Add(placeholder); - Inputs = new Tensors(inputs.ToArray()); - } + Inputs.Add(placeholder); } Tensor _create_substitute_placeholder(Tensor value, diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 7a1e2b65..3c599f6c 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -39,7 +39,8 @@ namespace Tensorflow public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) { _graph = g ?? ops.get_default_graph(); - _graph.as_default(); + if (!_graph.building_function) + _graph.as_default(); _target = Encoding.UTF8.GetBytes(target); using (var opts = new SessionOptions(target, config)) diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 49814f42..54af955a 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -58,9 +58,6 @@ namespace Tensorflow.Keras.Layers args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; } - // In graph mode, create a graph placeholder to call the layer on. - tf.Context.graph_mode(); - if (args.InputTensor == null) { if (args.InputShape != null) @@ -74,6 +71,9 @@ namespace Tensorflow.Keras.Layers args.BatchInputShape = null; } + var graph = keras.backend.get_graph(); + graph.as_default(); + args.InputTensor = keras.backend.placeholder( shape: BatchInputShape, dtype: DType, @@ -81,8 +81,8 @@ namespace Tensorflow.Keras.Layers sparse: args.Sparse, ragged: args.Ragged); - isPlaceholder = true; + tf.Context.restore_mode(); } // Create an input node to add to self.outbound_node @@ -97,8 +97,6 @@ namespace Tensorflow.Keras.Layers typeSpec = new TensorSpec(args.InputTensor.TensorShape, dtype: args.InputTensor.dtype, name: Name); - - tf.Context.restore_mode(); } public static InputLayer from_config(LayerArgs args) diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 0510a25c..fe93e584 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -151,23 +151,12 @@ namespace Tensorflow.Keras.Utils // recursively CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); - Layer op_layer = null; - /*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs + Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs { NodeDef = op.node_def, Constants = constants, Name = op.name - });*/ - op_layer = op.type switch - { - // "AddV2" => keras.layers.Add(), - _ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs - { - NodeDef = op.node_def, - Constants = constants, - Name = op.name - }) - }; + }); created_layers.Add(op_layer); op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); processed_ops.Add(op);