@@ -17,8 +17,8 @@ namespace Tensorflow.Graphs | |||||
IntPtr func_handle; | IntPtr func_handle; | ||||
public string FuncName => _graph_key; | public string FuncName => _graph_key; | ||||
public Tensors Inputs { get; set; } | |||||
public Tensors Outputs { get; set; } | |||||
public Tensors Inputs { get; set; } = new Tensors(); | |||||
public Tensors Outputs { get; set; } = new Tensors(); | |||||
public Dictionary<string, string> Attrs { get; set; } | public Dictionary<string, string> Attrs { get; set; } | ||||
public Dictionary<long, (Tensor, Tensor)> _captures | public Dictionary<long, (Tensor, Tensor)> _captures | ||||
@@ -175,14 +175,7 @@ namespace Tensorflow.Graphs | |||||
void add_capture(Tensor tensor, Tensor placeholder) | void add_capture(Tensor tensor, Tensor placeholder) | ||||
{ | { | ||||
_captures.Add(tensor.Id, (tensor, placeholder)); | _captures.Add(tensor.Id, (tensor, placeholder)); | ||||
if (Inputs == null) | |||||
Inputs = new Tensors(placeholder); | |||||
else | |||||
{ | |||||
var inputs = Inputs.ToList(); | |||||
inputs.Add(placeholder); | |||||
Inputs = new Tensors(inputs.ToArray()); | |||||
} | |||||
Inputs.Add(placeholder); | |||||
} | } | ||||
Tensor _create_substitute_placeholder(Tensor value, | Tensor _create_substitute_placeholder(Tensor value, | ||||
@@ -39,7 +39,8 @@ namespace Tensorflow | |||||
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | ||||
{ | { | ||||
_graph = g ?? ops.get_default_graph(); | _graph = g ?? ops.get_default_graph(); | ||||
_graph.as_default(); | |||||
if (!_graph.building_function) | |||||
_graph.as_default(); | |||||
_target = Encoding.UTF8.GetBytes(target); | _target = Encoding.UTF8.GetBytes(target); | ||||
using (var opts = new SessionOptions(target, config)) | using (var opts = new SessionOptions(target, config)) | ||||
@@ -58,9 +58,6 @@ namespace Tensorflow.Keras.Layers | |||||
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; | args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; | ||||
} | } | ||||
// In graph mode, create a graph placeholder to call the layer on. | |||||
tf.Context.graph_mode(); | |||||
if (args.InputTensor == null) | if (args.InputTensor == null) | ||||
{ | { | ||||
if (args.InputShape != null) | if (args.InputShape != null) | ||||
@@ -74,6 +71,9 @@ namespace Tensorflow.Keras.Layers | |||||
args.BatchInputShape = null; | args.BatchInputShape = null; | ||||
} | } | ||||
var graph = keras.backend.get_graph(); | |||||
graph.as_default(); | |||||
args.InputTensor = keras.backend.placeholder( | args.InputTensor = keras.backend.placeholder( | ||||
shape: BatchInputShape, | shape: BatchInputShape, | ||||
dtype: DType, | dtype: DType, | ||||
@@ -81,8 +81,8 @@ namespace Tensorflow.Keras.Layers | |||||
sparse: args.Sparse, | sparse: args.Sparse, | ||||
ragged: args.Ragged); | ragged: args.Ragged); | ||||
isPlaceholder = true; | isPlaceholder = true; | ||||
tf.Context.restore_mode(); | |||||
} | } | ||||
// Create an input node to add to self.outbound_node | // Create an input node to add to self.outbound_node | ||||
@@ -97,8 +97,6 @@ namespace Tensorflow.Keras.Layers | |||||
typeSpec = new TensorSpec(args.InputTensor.TensorShape, | typeSpec = new TensorSpec(args.InputTensor.TensorShape, | ||||
dtype: args.InputTensor.dtype, | dtype: args.InputTensor.dtype, | ||||
name: Name); | name: Name); | ||||
tf.Context.restore_mode(); | |||||
} | } | ||||
public static InputLayer from_config(LayerArgs args) | public static InputLayer from_config(LayerArgs args) | ||||
@@ -151,23 +151,12 @@ namespace Tensorflow.Keras.Utils | |||||
// recursively | // recursively | ||||
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | ||||
Layer op_layer = null; | |||||
/*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
{ | { | ||||
NodeDef = op.node_def, | NodeDef = op.node_def, | ||||
Constants = constants, | Constants = constants, | ||||
Name = op.name | Name = op.name | ||||
});*/ | |||||
op_layer = op.type switch | |||||
{ | |||||
// "AddV2" => keras.layers.Add(), | |||||
_ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
{ | |||||
NodeDef = op.node_def, | |||||
Constants = constants, | |||||
Name = op.name | |||||
}) | |||||
}; | |||||
}); | |||||
created_layers.Add(op_layer); | created_layers.Add(op_layer); | ||||
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); | op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); | ||||
processed_ops.Add(op); | processed_ops.Add(op); | ||||