diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 6f6657af..254f4ded 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -22,6 +22,7 @@
 using System.ComponentModel;
 using System.Diagnostics;
 using System.Linq;
 using NumSharp.Utilities;
+using System.Runtime.CompilerServices;
 
 namespace Tensorflow
 {
@@ -50,7 +51,7 @@ namespace Tensorflow
             => list.Add(element);
 
         public static void append<T>(this IList<T> list, T element)
-            => list.Add(element);
+            => list.Insert(list.Count, element);
 
         public static T[] concat<T>(this IList<T> list1, IList<T> list2)
         {
@@ -407,5 +408,37 @@ namespace Tensorflow
                     return true;
             return false;
         }
+
+        public static bool issubset<T>(this IEnumerable<T> subset, IEnumerable<T> src)
+        {
+            bool issubset = true;
+            foreach (var element in subset)
+            {
+                if (!src.Contains(element))
+                {
+                    issubset = false;
+                    break;
+                }
+            }
+
+            return issubset;
+        }
+
+        public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
+        {
+            if (dic.ContainsKey(key))
+                return dic[key];
+
+            dic[key] = value;
+            return value;
+        }
+
+        public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
+        {
+            if (dic.ContainsKey(key))
+                return dic[key];
+
+            return value;
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
index c1163e15..d8197285 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
             _channels_first = args.DataFormat == "channels_first";
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             if (_channels_first)
             {
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs
index af8e9114..e990fb86 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs
@@ -4,6 +4,7 @@
 using System.Linq;
 using System.Text;
 using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Keras.Engine
 {
@@ -21,6 +22,11 @@ namespace Tensorflow.Keras.Engine
         List<Layer> _input_layers;
         List<KerasHistory> _input_coordinates;
         List<KerasHistory> _output_coordinates;
+        public string[] NetworkNodes { get; set; }
+        public Dictionary<int, List<Node>> NodesByDepth { get; set; }
+        public List<Layer> Layers { get; set; }
+        Dictionary<int, int> tensor_usage_count;
+        public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
 
         public Functional(Tensors inputs, Tensors outputs)
             : base(new ModelArgs
             {
@@ -33,6 +39,7 @@ namespace Tensorflow.Keras.Engine
             _output_layers = new List<Layer>();
             _input_coordinates = new List<KerasHistory>();
             _output_coordinates = new List<KerasHistory>();
+            tensor_usage_count = new Dictionary<int, int>();
             _init_graph_network(inputs, outputs);
         }
 
@@ -67,16 +74,253 @@ namespace Tensorflow.Keras.Engine
                 _input_layers.append(layer);
                 _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x));
             }
+
+            // Keep track of the network's nodes and layers.
+            var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs);
+
+            NetworkNodes = nodes;
+            NodesByDepth = nodes_by_depth;
+            Layers = layers;
+
+            ComputeTensorUsageCount();
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        void ComputeTensorUsageCount()
         {
-            return run_internal_graph(inputs, state, is_training);
+            var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();
+            var depth_keys = NodesByDepth.Keys.Reverse().Skip(1).ToArray();
+            foreach(var depth in depth_keys)
+            {
+                foreach(var node in NodesByDepth[depth])
+                {
+                    var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray();
+                    if (input_tensors.issubset(available_tensors))
+                    {
+                        foreach (var tensor in node.KerasInputs)
+                        {
+                            if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
+                                tensor_usage_count[tensor.GetHashCode()] = 0;
+                            tensor_usage_count[tensor.GetHashCode()] += 1;
+                        }
+
+                        foreach (var output_tensor in node.Outputs)
+                            available_tensors.Add(output_tensor.GetHashCode());
+                    }
+                }
+            }
+
+            foreach (var tensor in outputs)
+            {
+                if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
+                    tensor_usage_count[tensor.GetHashCode()] = 0;
+                tensor_usage_count[tensor.GetHashCode()] += 1;
+            }
+        }
+
+        /// <summary>
+        /// Validates a network's topology and gathers its layers and nodes.
+        /// </summary>
+        /// <param name="inputs"></param>
+        /// <param name="outputs"></param>
+        (string[], Dictionary<int, List<Node>>, List<Layer>, Dictionary<int, List<Layer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
+        {
+            var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);
+            var network_nodes = nodes_in_decreasing_depth
+                .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
+                .ToArray();
+
+            var nodes_depths = new Dictionary<Node, int>();
+            var layers_depths = new Dictionary<Layer, int>();
+
+            nodes_in_decreasing_depth.Reverse();
+            foreach (var node in nodes_in_decreasing_depth)
+            {
+                // If the depth is not set, the node has no outbound nodes (depth 0).
+                int depth = nodes_depths.SetDefault(node, 0);
+                // Update the depth of the corresponding layer
+                int previous_depth = layers_depths.Get(node.Layer, 0);
+                // If we've seen this layer before at a higher depth,
+                // we should use that depth instead of the node depth.
+                // This is necessary for shared layers that have inputs at different
+                // depth levels in the graph.
+                depth = Math.Max(depth, previous_depth);
+                layers_depths[node.Layer] = depth;
+                nodes_depths[node] = depth;
+
+                // Update the depth of inbound nodes.
+                // The "depth" of a node is the max of the depths
+                // of all nodes it is connected to + 1.
+                foreach(var node_dep in node.ParentNodes)
+                {
+                    previous_depth = nodes_depths.Get(node_dep, 0);
+                    nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
+                }
+            }
+
+            // Handle inputs that are not connected to outputs.
+            // We do not error out here because the inputs may be used to compute losses
+            // and metrics.
+            foreach(var input_t in inputs)
+            {
+                var (input_layer, _, _) = input_t.KerasHistory;
+                if (!layers_depths.ContainsKey(input_layer))
+                {
+                    layers_depths[input_layer] = 0;
+                    layer_indices[input_layer] = -1;
+                    nodes_depths[input_layer.InboundNodes[0]] = 0;
+                    network_nodes.add(MakeNodeKey(input_layer.Name, 0));
+                }
+            }
+
+            // Build a dict {depth: list of nodes with this depth}
+            var nodes_by_depth = new Dictionary<int, List<Node>>();
+            foreach (var node in nodes_depths)
+            {
+                if (!nodes_by_depth.ContainsKey(node.Value))
+                    nodes_by_depth[node.Value] = new List<Node>();
+                nodes_by_depth[node.Value].append(node.Key);
+            }
+
+            var layers_by_depth = new Dictionary<int, List<Layer>>();
+            foreach (var layer in layers_depths)
+            {
+                if (!layers_by_depth.ContainsKey(layer.Value))
+                    layers_by_depth[layer.Value] = new List<Layer>();
+                layers_by_depth[layer.Value].append(layer.Key);
+            }
+
+            // Get sorted list of layer depths.
+            var depth_keys = layers_by_depth.Keys.Reverse();
+
+            // Set self.layers ordered by depth.
+            var layers = new List<Layer>();
+            foreach(var depth in depth_keys)
+            {
+                var layers_for_depth = layers_by_depth[depth];
+
+                // Network.layers needs to have a deterministic order:
+                // here we order them by traversal order.
+                layers_for_depth.Reverse();
+                layers.AddRange(layers_for_depth);
+            }
+
+            // Get sorted list of node depths.
+            depth_keys = nodes_by_depth.Keys.Reverse();
+
+            return (network_nodes, nodes_by_depth, layers, layers_by_depth);
         }
 
-        Tensors run_internal_graph(Tensors inputs, Tensor state = null, bool is_training = false)
+        string MakeNodeKey(string layer_name, int node_index)
+            => $"{layer_name}_ib-{node_index}";
+
+        /// <summary>
+        /// This method topologically sorts nodes in order from inputs to outputs.
+        /// </summary>
+        /// <param name="outputs"></param>
+        (List<Node>, Dictionary<Layer, int>) BuildMap(Tensors outputs)
         {
+            var finished_nodes = new List<Node>();
+            var nodes_in_progress = new List<Node>();
+            var nodes_in_decreasing_depth = new List<Node>();
+            var layer_indices = new Dictionary<Layer, int>();
+            foreach (var output in outputs)
+                BuildMapHelper(output,
+                    finished_nodes,
+                    nodes_in_progress,
+                    nodes_in_decreasing_depth,
+                    layer_indices);
+
+            return (nodes_in_decreasing_depth, layer_indices);
+        }
+
+        void BuildMapHelper(Tensor tensor,
+            List<Node> finished_nodes,
+            List<Node> nodes_in_progress,
+            List<Node> nodes_in_decreasing_depth,
+            Dictionary<Layer, int> layer_indices)
+        {
+            var (layer, node_index, _) = tensor.KerasHistory;
+            var node = layer.InboundNodes[node_index];
+
+            // Don't repeat work for shared subgraphs
+            if (finished_nodes.Contains(node))
+                return;
+
+            // Prevent cycles.
+            if (nodes_in_progress.Contains(node))
+                throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");
+
+            // Store the traversal order for layer sorting.
+            if (!layer_indices.ContainsKey(layer))
+                layer_indices[layer] = layer_indices.Count;
+
+            // Propagate to all previous tensors connected to this node.
+            nodes_in_progress.Add(node);
+            foreach (var k_tensor in node.KerasInputs)
+                BuildMapHelper(k_tensor,
+                    finished_nodes,
+                    nodes_in_progress,
+                    nodes_in_decreasing_depth,
+                    layer_indices);
+
+            finished_nodes.Add(node);
+            nodes_in_progress.Remove(node);
+            nodes_in_decreasing_depth.Insert(nodes_in_decreasing_depth.Count, node);
+        }
+
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
+        {
+            return run_internal_graph(inputs, is_training);
+        }
+
+        Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
+        {
+            if (mask != null)
+            {
+                Tensor[] masks = new Tensor[inputs.Count()];
+                foreach (var (i, input_t) in enumerate(inputs))
+                    input_t.KerasMask = masks[i];
+            }
+
+            var tensor_dict = new Dictionary<int, Tensor[]>();
+            foreach (var (x, y) in zip(this.inputs, inputs))
+            {
+                var y1 = conform_to_reference_input(y, x);
+                var x_id = x.GetHashCode();
+                tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(_ => y1).ToArray();
+            }
+
+            var depth_keys = NodesByDepth.Keys.Reverse().ToArray();
+
+            foreach(var depth in depth_keys)
+            {
+                var nodes = NodesByDepth[depth];
+                foreach(var node in nodes)
+                {
+                    // Input tensors already exist.
+                    if (node.IsInput)
+                        continue;
+
+                    var layer_inputs = new Tensors(tensor_dict[node.FlatInputIds[0]]);
+                    tensor_dict[node.FlatInputIds[0]] = new Tensor[0];
+
+                    var outputs = node.Layer.Apply(layer_inputs, is_training: training);
+                    // Update tensor_dict.
+                    foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
+                        tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray();
+                }
+            }
+
+            foreach(var x in outputs)
+            {
+
+            }
             throw new NotImplementedException("");
         }
+
+        Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)
+        {
+            return tensor;
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
index 05eb6fa7..bdee38ce 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
@@ -9,10 +9,10 @@ namespace Tensorflow.Keras.Engine
     /// </summary>
     public class KerasHistory
     {
-        public Layer layer;
+        Layer layer;
         int node_index;
         int tensor_index;
-        public Tensor tensor;
+        Tensor tensor;
 
         public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor)
         {
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
index 9513b26c..a2b6ef2d 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
                 if (!built)
                     MaybeBuild(inputs);
 
-                outputs = call_fn(inputs, state: state, is_training: is_training);
+                outputs = CallFn(inputs, state: state, is_training: is_training);
 
                 outputs = _set_connectivity_metadata_(inputs, outputs);
                 _handle_activity_regularization(inputs, outputs);
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
index 14d61c31..a32952cb 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine
                 if (!dynamic)
                     throw new NotImplementedException("");
 
-                outputs = call_fn(inputs);
+                outputs = CallFn(inputs);
 
                 outputs = _set_connectivity_metadata_(inputs, outputs);
                 _handle_activity_regularization(inputs, outputs);
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index c88a2263..d64f0d1c 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine
         ///
         ///
         ///
-        protected virtual Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             throw new NotImplementedException("");
         }
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs
index 0ae84ac8..72d560c7 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs
@@ -39,20 +39,42 @@ namespace Tensorflow.Keras.Engine
         public Tensors Outputs => args.Outputs;
         public TensorShape[] input_shapes;
         public TensorShape[] output_shapes;
-        List<Tensor> kerasInputs = new List<Tensor>();
+        public List<Tensor> KerasInputs = new List<Tensor>();
+        public Layer Layer { get; set; }
+        public bool IsInput => args.InputTensors == null;
+        public int[] FlatInputIds { get; set; }
+        public int[] FlatOutputIds { get; set; }
+
+        public Node[] ParentNodes
+        {
+            get
+            {
+                var node_deps = new List<Node>();
+                foreach(var kt in KerasInputs)
+                {
+                    var (layer, node_index, _) = kt.KerasHistory;
+                    if (layer != null)
+                        node_deps.append(layer.InboundNodes[node_index]);
+                }
+                return node_deps.ToArray();
+            }
+        }
 
         public Node(Layer layer, NodeArgs args)
         {
             this.args = args;
+            this.Layer = layer;
 
             if (args.InputTensors != null)
-                kerasInputs.AddRange(args.InputTensors);
+                KerasInputs.AddRange(args.InputTensors);
 
             // Wire up Node to Layers.
             layer.InboundNodes.Add(this);
-            foreach (var kt in kerasInputs)
+            foreach (var kt in KerasInputs)
             {
-                var inbound_layer = kt.KerasHistory.layer;
+                if (kt.KerasHistory == null)
+                    continue;
+                var (inbound_layer, _, _) = kt.KerasHistory;
                 if (inbound_layer != null)
                     inbound_layer.OutboundNodes.Add(this);
             }
@@ -61,6 +83,10 @@ namespace Tensorflow.Keras.Engine
             var node_index = layer.InboundNodes.Count - 1;
             foreach (var (i, tensor) in enumerate(Outputs))
                 tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor);
+
+            // Cached for performance.
+            FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray();
+            FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray();
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
index 70a1458f..423287e0 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
@@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine
             built = true;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
-            return base.call_fn(inputs, state, is_training);
+            return base.CallFn(inputs, state, is_training);
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index 0f855915..fc32a792 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers
             built = true;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             Tensor outputs = null;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
index 7b358d75..d8b4bad9 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers
             built = true;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false)
         {
             var outputs = _convolution_op.Apply(inputs, kernel);
             if (use_bias)
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
index 2bdda94c..7eed3a63 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
             built = true;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false)
         {
             Tensor outputs = null;
             var rank = inputs.rank;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
index cd53a7a2..ec4cebae 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
             this.args = args;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             var output = tf_utils.smart_cond(is_training,
                 () => tf.nn.dropout(inputs,
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
index fddafbc9..ef85d8a4 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
             built = true;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             var dtype = inputs.dtype;
             if (dtype != tf.int32 && dtype != tf.int64)
diff --git a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
index 2ef11ed8..266081c0 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
                 .ToArray();
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
-            return base.call_fn(inputs, state: state, is_training: is_training);
+            return base.CallFn(inputs, state: state, is_training: is_training);
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
index 1cccb598..a099caf2 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
             input_spec = new InputSpec(ndim: 4);
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             int[] pool_shape;
             int[] strides;
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
index 112f427e..b542bcbd 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
             this.args = args;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             scale = math_ops.cast(args.Scale, args.DType);
             offset = math_ops.cast(args.Offset, args.DType);
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs b/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs
index 5b738426..2790857b 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers
             this.input_spec = new InputSpec(ndim: 4);
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             return tf.keras.backend.spatial_2d_padding(inputs,
                 padding: padding,
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
index 84293b72..7a8b4311 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
@@ -74,7 +74,7 @@ namespace Tensorflow
         ///
         ///
         ///
         ///
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
         {
             var one = constant_op.constant(1, dtype: dtypes.int32);
             // Parameters of gates are concatenated into one multiply for efficiency.
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
index 23de35dc..f1f49792 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
@@ -67,7 +67,7 @@ namespace Tensorflow
             built = true;
         }
 
-        protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
        {
             // Most basic RNN: output = new_state = act(W * input + U * state + B).
             var concat = array_ops.concat(new Tensor[] { inputs, state }, 1);
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 7d4f57d9..092756db 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -145,6 +145,7 @@ namespace Tensorflow
         /// Keras History: (Layer, (node_index, tensor_index))
         /// </summary>
         public KerasHistory KerasHistory { get; set; }
+        public Tensor KerasMask { get; set; }
 
         /// <summary>
        /// Updates the shape of this tensor.
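
Side note, illustrative only and not part of the diff above: the SetDefault and Get helpers added to Binding.Util.cs mirror Python's dict.setdefault() and dict.get(), and MapGraphNetwork relies on them for its node/layer depth bookkeeping. Below is a minimal self-contained sketch of their semantics; the two helpers are copied from the hunk above so the snippet compiles without referencing TensorFlow.NET, and the node/layer names are made up for the example.

// Illustrative sketch, not part of the commit above.
using System;
using System.Collections.Generic;

static class DictionaryExtensions
{
    // Like Python's dict.setdefault(): store the fallback when the key is
    // missing, then return whatever the dictionary now holds.
    public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
    {
        if (dic.ContainsKey(key))
            return dic[key];

        dic[key] = value;
        return value;
    }

    // Like Python's dict.get(key, default): return the stored value or the
    // fallback, without mutating the dictionary.
    public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
    {
        if (dic.ContainsKey(key))
            return dic[key];

        return value;
    }
}

class DepthBookkeepingDemo
{
    static void Main()
    {
        // Stand-ins for nodes_depths / layers_depths in MapGraphNetwork,
        // keyed by name here instead of Node/Layer objects for brevity.
        var nodes_depths = new Dictionary<string, int>();
        var layers_depths = new Dictionary<string, int>();

        // A node seen for the first time gets depth 0 and is recorded.
        int depth = nodes_depths.SetDefault("dense_ib-0", 0);   // 0, entry created

        // Reading a layer that has never been seen does not add an entry.
        int previous_depth = layers_depths.Get("dense", 0);     // 0, nothing stored

        // Same max(node depth, previously seen layer depth) rule as the hunk above.
        depth = Math.Max(depth, previous_depth);
        layers_depths["dense"] = depth;
        nodes_depths["dense_ib-0"] = depth;

        Console.WriteLine($"{nodes_depths.Count} node(s), {layers_depths.Count} layer(s) tracked");
    }
}

Both helpers keep the C# port close to the corresponding dict calls in Keras' _map_graph_network, which appears to be why the Functional.cs hunk adds "using static Tensorflow.Binding;".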