|
|
@@ -150,6 +150,9 @@ namespace Tensorflow.Keras.Engine |
|
|
|
|
|
|
|
void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) |
|
|
|
{ |
|
|
|
if (_inferred_input_shape == input_shape) |
|
|
|
return; |
|
|
|
|
|
|
|
ops.init_scope(); |
|
|
|
var inputs = keras.Input(batch_input_shape: input_shape, |
|
|
|
dtype: input_dtype, |
|
|
@@ -157,16 +160,17 @@ namespace Tensorflow.Keras.Engine |
|
|
|
Tensors layer_input = inputs; |
|
|
|
Tensors layer_output = null; |
|
|
|
Tensors outputs = null; |
|
|
|
|
|
|
|
List<INode> created_nodes = new List<INode>(); |
|
|
|
foreach (var layer in _layers) |
|
|
|
{ |
|
|
|
clear_previously_created_nodes(layer, _created_nodes); |
|
|
|
layer_output = layer.Apply(layer_input); |
|
|
|
// Keep track of nodes just created above |
|
|
|
track_nodes_created_by_last_call(layer, _created_nodes); |
|
|
|
track_nodes_created_by_last_call(layer, created_nodes); |
|
|
|
layer_input = layer_output; |
|
|
|
outputs = layer_output; |
|
|
|
} |
|
|
|
_created_nodes = created_nodes; |
|
|
|
_init_graph_network(inputs, outputs); |
|
|
|
_graph_initialized = true; |
|
|
|
_inferred_input_shape = input_shape; |
|
|
@@ -174,16 +178,28 @@ namespace Tensorflow.Keras.Engine |
|
|
|
|
|
|
|
void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) |
|
|
|
{ |
|
|
|
foreach(var node in layer.InboundNodes) |
|
|
|
{ |
|
|
|
foreach(var prev_layer in node.InboundLayers) |
|
|
|
{ |
|
|
|
var outNodes = prev_layer.OutboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); |
|
|
|
prev_layer.OutboundNodes.Clear(); |
|
|
|
prev_layer.OutboundNodes.AddRange(outNodes); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
var inNodes = layer.InboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); |
|
|
|
layer.InboundNodes.Clear(); |
|
|
|
layer.InboundNodes.AddRange(inNodes); |
|
|
|
} |
|
|
|
|
|
|
|
void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) |
|
|
|
{ |
|
|
|
var node = layer.InboundNodes.Last(); |
|
|
|
created_nodes.Add(node); |
|
|
|
foreach(var prev_layer in node.iterate_inbound()) |
|
|
|
foreach(var prev_layer in node.InboundLayers) |
|
|
|
{ |
|
|
|
created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); |
|
|
|
created_nodes.add(prev_layer.OutboundNodes.Last()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|