Browse Source

Fix graph instance in InputLayer.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
58d2daebfa
4 changed files with 11 additions and 30 deletions
  1. +3
    -10
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +4
    -6
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  4. +2
    -13
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs

+ 3
- 10
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -17,8 +17,8 @@ namespace Tensorflow.Graphs
IntPtr func_handle;
public string FuncName => _graph_key;

public Tensors Inputs { get; set; }
public Tensors Outputs { get; set; }
public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
public Dictionary<string, string> Attrs { get; set; }

public Dictionary<long, (Tensor, Tensor)> _captures
@@ -175,14 +175,7 @@ namespace Tensorflow.Graphs
void add_capture(Tensor tensor, Tensor placeholder)
{
_captures.Add(tensor.Id, (tensor, placeholder));
if (Inputs == null)
Inputs = new Tensors(placeholder);
else
{
var inputs = Inputs.ToList();
inputs.Add(placeholder);
Inputs = new Tensors(inputs.ToArray());
}
Inputs.Add(placeholder);
}

Tensor _create_substitute_placeholder(Tensor value,


+ 2
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -39,7 +39,8 @@ namespace Tensorflow
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{
_graph = g ?? ops.get_default_graph();
_graph.as_default();
if (!_graph.building_function)
_graph.as_default();
_target = Encoding.UTF8.GetBytes(target);

using (var opts = new SessionOptions(target, config))


+ 4
- 6
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -58,9 +58,6 @@ namespace Tensorflow.Keras.Layers
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype;
}

// In graph mode, create a graph placeholder to call the layer on.
tf.Context.graph_mode();

if (args.InputTensor == null)
{
if (args.InputShape != null)
@@ -74,6 +71,9 @@ namespace Tensorflow.Keras.Layers
args.BatchInputShape = null;
}

var graph = keras.backend.get_graph();
graph.as_default();

args.InputTensor = keras.backend.placeholder(
shape: BatchInputShape,
dtype: DType,
@@ -81,8 +81,8 @@ namespace Tensorflow.Keras.Layers
sparse: args.Sparse,
ragged: args.Ragged);


isPlaceholder = true;
tf.Context.restore_mode();
}

// Create an input node to add to self.outbound_node
@@ -97,8 +97,6 @@ namespace Tensorflow.Keras.Layers
typeSpec = new TensorSpec(args.InputTensor.TensorShape,
dtype: args.InputTensor.dtype,
name: Name);

tf.Context.restore_mode();
}

public static InputLayer from_config(LayerArgs args)


+ 2
- 13
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -151,23 +151,12 @@ namespace Tensorflow.Keras.Utils

// recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
Layer op_layer = null;
/*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Constants = constants,
Name = op.name
});*/
op_layer = op.type switch
{
// "AddV2" => keras.layers.Add(),
_ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Constants = constants,
Name = op.name
})
};
});
created_layers.Add(op_layer);
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
processed_ops.Add(op);


Loading…
Cancel
Save