Browse Source

Fix graph instance in InputLayer.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
58d2daebfa
4 changed files with 11 additions and 30 deletions
  1. +3
    -10
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +4
    -6
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  4. +2
    -13
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs

+ 3
- 10
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -17,8 +17,8 @@ namespace Tensorflow.Graphs
IntPtr func_handle; IntPtr func_handle;
public string FuncName => _graph_key; public string FuncName => _graph_key;


public Tensors Inputs { get; set; }
public Tensors Outputs { get; set; }
public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
public Dictionary<string, string> Attrs { get; set; } public Dictionary<string, string> Attrs { get; set; }


public Dictionary<long, (Tensor, Tensor)> _captures public Dictionary<long, (Tensor, Tensor)> _captures
@@ -175,14 +175,7 @@ namespace Tensorflow.Graphs
void add_capture(Tensor tensor, Tensor placeholder) void add_capture(Tensor tensor, Tensor placeholder)
{ {
_captures.Add(tensor.Id, (tensor, placeholder)); _captures.Add(tensor.Id, (tensor, placeholder));
if (Inputs == null)
Inputs = new Tensors(placeholder);
else
{
var inputs = Inputs.ToList();
inputs.Add(placeholder);
Inputs = new Tensors(inputs.ToArray());
}
Inputs.Add(placeholder);
} }


Tensor _create_substitute_placeholder(Tensor value, Tensor _create_substitute_placeholder(Tensor value,


+ 2
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -39,7 +39,8 @@ namespace Tensorflow
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{ {
_graph = g ?? ops.get_default_graph(); _graph = g ?? ops.get_default_graph();
_graph.as_default();
if (!_graph.building_function)
_graph.as_default();
_target = Encoding.UTF8.GetBytes(target); _target = Encoding.UTF8.GetBytes(target);


using (var opts = new SessionOptions(target, config)) using (var opts = new SessionOptions(target, config))


+ 4
- 6
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -58,9 +58,6 @@ namespace Tensorflow.Keras.Layers
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype;
} }


// In graph mode, create a graph placeholder to call the layer on.
tf.Context.graph_mode();

if (args.InputTensor == null) if (args.InputTensor == null)
{ {
if (args.InputShape != null) if (args.InputShape != null)
@@ -74,6 +71,9 @@ namespace Tensorflow.Keras.Layers
args.BatchInputShape = null; args.BatchInputShape = null;
} }


var graph = keras.backend.get_graph();
graph.as_default();

args.InputTensor = keras.backend.placeholder( args.InputTensor = keras.backend.placeholder(
shape: BatchInputShape, shape: BatchInputShape,
dtype: DType, dtype: DType,
@@ -81,8 +81,8 @@ namespace Tensorflow.Keras.Layers
sparse: args.Sparse, sparse: args.Sparse,
ragged: args.Ragged); ragged: args.Ragged);



isPlaceholder = true; isPlaceholder = true;
tf.Context.restore_mode();
} }


// Create an input node to add to self.outbound_node // Create an input node to add to self.outbound_node
@@ -97,8 +97,6 @@ namespace Tensorflow.Keras.Layers
typeSpec = new TensorSpec(args.InputTensor.TensorShape, typeSpec = new TensorSpec(args.InputTensor.TensorShape,
dtype: args.InputTensor.dtype, dtype: args.InputTensor.dtype,
name: Name); name: Name);

tf.Context.restore_mode();
} }


public static InputLayer from_config(LayerArgs args) public static InputLayer from_config(LayerArgs args)


+ 2
- 13
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -151,23 +151,12 @@ namespace Tensorflow.Keras.Utils


// recursively // recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
Layer op_layer = null;
/*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{ {
NodeDef = op.node_def, NodeDef = op.node_def,
Constants = constants, Constants = constants,
Name = op.name Name = op.name
});*/
op_layer = op.type switch
{
// "AddV2" => keras.layers.Add(),
_ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Constants = constants,
Name = op.name
})
};
});
created_layers.Add(op_layer); created_layers.Add(op_layer);
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
processed_ops.Add(op); processed_ops.Add(op);


Loading…
Cancel
Save