@@ -22,6 +22,7 @@ using System.ComponentModel; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
using System.Runtime.CompilerServices; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -50,7 +51,7 @@ namespace Tensorflow | |||||
=> list.Add(element); | => list.Add(element); | ||||
public static void append<T>(this IList<T> list, T element) | public static void append<T>(this IList<T> list, T element) | ||||
=> list.Add(element); | |||||
=> list.Insert(list.Count, element); | |||||
public static T[] concat<T>(this IList<T> list1, IList<T> list2) | public static T[] concat<T>(this IList<T> list1, IList<T> list2) | ||||
{ | { | ||||
@@ -407,5 +408,37 @@ namespace Tensorflow | |||||
return true; | return true; | ||||
return false; | return false; | ||||
} | } | ||||
public static bool issubset<T>(this IEnumerable<T> subset, IEnumerable<T> src) | |||||
{ | |||||
bool issubset = true; | |||||
foreach (var element in subset) | |||||
{ | |||||
if (!src.Contains(element)) | |||||
{ | |||||
issubset = false; | |||||
continue; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value) | |||||
{ | |||||
if (dic.ContainsKey(key)) | |||||
return dic[key]; | |||||
dic[key] = value; | |||||
return value; | |||||
} | |||||
public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value) | |||||
{ | |||||
if (dic.ContainsKey(key)) | |||||
return dic[key]; | |||||
return value; | |||||
} | |||||
} | } | ||||
} | } |
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||||
_channels_first = args.DataFormat == "channels_first"; | _channels_first = args.DataFormat == "channels_first"; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
if (_channels_first) | if (_channels_first) | ||||
{ | { | ||||
@@ -4,6 +4,7 @@ using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -21,6 +22,11 @@ namespace Tensorflow.Keras.Engine | |||||
List<Layer> _input_layers; | List<Layer> _input_layers; | ||||
List<KerasHistory> _input_coordinates; | List<KerasHistory> _input_coordinates; | ||||
List<KerasHistory> _output_coordinates; | List<KerasHistory> _output_coordinates; | ||||
public string[] NetworkNodes { get; set; } | |||||
public Dictionary<int, List<Node>> NodesByDepth { get; set; } | |||||
public List<Layer> Layers { get; set; } | |||||
Dictionary<int, int> tensor_usage_count; | |||||
public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | |||||
public Functional(Tensors inputs, Tensors outputs) | public Functional(Tensors inputs, Tensors outputs) | ||||
: base(new ModelArgs | : base(new ModelArgs | ||||
@@ -33,6 +39,7 @@ namespace Tensorflow.Keras.Engine | |||||
_output_layers = new List<Layer>(); | _output_layers = new List<Layer>(); | ||||
_input_coordinates = new List<KerasHistory>(); | _input_coordinates = new List<KerasHistory>(); | ||||
_output_coordinates = new List<KerasHistory>(); | _output_coordinates = new List<KerasHistory>(); | ||||
tensor_usage_count = new Dictionary<int, int>(); | |||||
_init_graph_network(inputs, outputs); | _init_graph_network(inputs, outputs); | ||||
} | } | ||||
@@ -67,16 +74,253 @@ namespace Tensorflow.Keras.Engine | |||||
_input_layers.append(layer); | _input_layers.append(layer); | ||||
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | ||||
} | } | ||||
// Keep track of the network's nodes and layers. | |||||
var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); | |||||
NetworkNodes = nodes; | |||||
NodesByDepth = nodes_by_depth; | |||||
Layers = layers; | |||||
ComputeTensorUsageCount(); | |||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
void ComputeTensorUsageCount() | |||||
{ | { | ||||
return run_internal_graph(inputs, state, is_training); | |||||
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList(); | |||||
var depth_keys = NodesByDepth.Keys.Reverse().Skip(1).ToArray(); | |||||
foreach(var depth in depth_keys) | |||||
{ | |||||
foreach(var node in NodesByDepth[depth]) | |||||
{ | |||||
var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray(); | |||||
if (input_tensors.issubset(available_tensors)) | |||||
{ | |||||
foreach (var tensor in node.KerasInputs) | |||||
{ | |||||
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode())) | |||||
tensor_usage_count[tensor.GetHashCode()] = 0; | |||||
tensor_usage_count[tensor.GetHashCode()] += 1; | |||||
} | |||||
foreach (var output_tensor in node.Outputs) | |||||
available_tensors.Add(output_tensor.GetHashCode()); | |||||
} | |||||
} | |||||
} | |||||
foreach (var tensor in outputs) | |||||
{ | |||||
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode())) | |||||
tensor_usage_count[tensor.GetHashCode()] = 0; | |||||
tensor_usage_count[tensor.GetHashCode()] += 1; | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Validates a network's topology and gather its layers and nodes. | |||||
/// </summary> | |||||
/// <param name="inputs"></param> | |||||
/// <param name="outputs"></param> | |||||
(string[], Dictionary<int, List<Node>>, List<Layer>, Dictionary<int, List<Layer>>) MapGraphNetwork(Tensors inputs, Tensors outputs) | |||||
{ | |||||
var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs); | |||||
var network_nodes = nodes_in_decreasing_depth | |||||
.Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node))) | |||||
.ToArray(); | |||||
var nodes_depths = new Dictionary<Node, int>(); | |||||
var layers_depths = new Dictionary<Layer, int>(); | |||||
nodes_in_decreasing_depth.Reverse(); | |||||
foreach (var node in nodes_in_decreasing_depth) | |||||
{ | |||||
// If the depth is not set, the node has no outbound nodes (depth 0). | |||||
int depth = nodes_depths.SetDefault(node, 0); | |||||
// Update the depth of the corresponding layer | |||||
int previous_depth = layers_depths.Get(node.Layer, 0); | |||||
// If we've seen this layer before at a higher depth, | |||||
// we should use that depth instead of the node depth. | |||||
// This is necessary for shared layers that have inputs at different | |||||
// depth levels in the graph. | |||||
depth = Math.Max(depth, previous_depth); | |||||
layers_depths[node.Layer] = depth; | |||||
nodes_depths[node] = depth; | |||||
// Update the depth of inbound nodes. | |||||
// The "depth" of a node is the max of the depths | |||||
// of all nodes it is connected to + 1. | |||||
foreach(var node_dep in node.ParentNodes) | |||||
{ | |||||
previous_depth = nodes_depths.Get(node_dep, 0); | |||||
nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth); | |||||
} | |||||
} | |||||
// Handle inputs that are not connected to outputs. | |||||
// We do not error out here because the inputs may be used to compute losses | |||||
// and metrics. | |||||
foreach(var input_t in inputs) | |||||
{ | |||||
var (input_layer, _, _) = input_t.KerasHistory; | |||||
if (!layers_depths.ContainsKey(input_layer)) | |||||
{ | |||||
layers_depths[input_layer] = 0; | |||||
layer_indices[input_layer] = -1; | |||||
nodes_depths[input_layer.InboundNodes[0]] = 0; | |||||
network_nodes.add(MakeNodeKey(input_layer.Name, 0)); | |||||
} | |||||
} | |||||
// Build a dict {depth: list of nodes with this depth} | |||||
var nodes_by_depth = new Dictionary<int, List<Node>>(); | |||||
foreach (var node in nodes_depths) | |||||
{ | |||||
if (!nodes_by_depth.ContainsKey(node.Value)) | |||||
nodes_by_depth[node.Value] = new List<Node>(); | |||||
nodes_by_depth[node.Value].append(node.Key); | |||||
} | |||||
var layers_by_depth = new Dictionary<int, List<Layer>>(); | |||||
foreach (var layer in layers_depths) | |||||
{ | |||||
if (!layers_by_depth.ContainsKey(layer.Value)) | |||||
layers_by_depth[layer.Value] = new List<Layer>(); | |||||
layers_by_depth[layer.Value].append(layer.Key); | |||||
} | |||||
// Get sorted list of layer depths. | |||||
var depth_keys = layers_by_depth.Keys.Reverse(); | |||||
// Set self.layers ordered by depth. | |||||
var layers = new List<Layer>(); | |||||
foreach(var depth in depth_keys) | |||||
{ | |||||
var layers_for_depth = layers_by_depth[depth]; | |||||
// Network.layers needs to have a deterministic order: | |||||
// here we order them by traversal order. | |||||
layers_for_depth.Reverse(); | |||||
layers.AddRange(layers_for_depth); | |||||
} | |||||
// Get sorted list of node depths. | |||||
depth_keys = nodes_by_depth.Keys.Reverse(); | |||||
return (network_nodes, nodes_by_depth, layers, layers_by_depth); | |||||
} | } | ||||
Tensors run_internal_graph(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
string MakeNodeKey(string layer_name, int node_index) | |||||
=> $"{layer_name}_ib-{node_index}"; | |||||
/// <summary> | |||||
/// This method topologically sorts nodes in order from inputs to outputs. | |||||
/// </summary> | |||||
/// <param name="outputs"></param> | |||||
(List<Node>, Dictionary<Layer, int>) BuildMap(Tensors outputs) | |||||
{ | { | ||||
var finished_nodes = new List<Node>(); | |||||
var nodes_in_progress = new List<Node>(); | |||||
var nodes_in_decreasing_depth = new List<Node>(); | |||||
var layer_indices = new Dictionary<Layer, int>(); | |||||
foreach (var output in outputs) | |||||
BuildMapHelper(output, | |||||
finished_nodes, | |||||
nodes_in_progress, | |||||
nodes_in_decreasing_depth, | |||||
layer_indices); | |||||
return (nodes_in_decreasing_depth, layer_indices); | |||||
} | |||||
void BuildMapHelper(Tensor tensor, | |||||
List<Node> finished_nodes, | |||||
List<Node> nodes_in_progress, | |||||
List<Node> nodes_in_decreasing_depth, | |||||
Dictionary<Layer, int> layer_indices) | |||||
{ | |||||
var (layer, node_index, _) = tensor.KerasHistory; | |||||
var node = layer.InboundNodes[node_index]; | |||||
// Don't repeat work for shared subgraphs | |||||
if (finished_nodes.Contains(node)) | |||||
return; | |||||
// Prevent cycles. | |||||
if (nodes_in_progress.Contains(node)) | |||||
throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle."); | |||||
// Store the traversal order for layer sorting. | |||||
if (!layer_indices.ContainsKey(layer)) | |||||
layer_indices[layer] = layer_indices.Count; | |||||
// Propagate to all previous tensors connected to this node. | |||||
nodes_in_progress.Add(node); | |||||
foreach (var k_tensor in node.KerasInputs) | |||||
BuildMapHelper(k_tensor, | |||||
finished_nodes, | |||||
nodes_in_progress, | |||||
nodes_in_decreasing_depth, | |||||
layer_indices); | |||||
finished_nodes.Add(node); | |||||
nodes_in_progress.Remove(node); | |||||
nodes_in_decreasing_depth.Insert(nodes_in_decreasing_depth.Count, node); | |||||
} | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | |||||
return run_internal_graph(inputs, is_training); | |||||
} | |||||
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) | |||||
{ | |||||
if (mask != null) | |||||
{ | |||||
Tensor[] masks = new Tensor[inputs.Count()]; | |||||
foreach (var (i, input_t) in enumerate(inputs)) | |||||
input_t.KerasMask = masks[i]; | |||||
} | |||||
var tensor_dict = new Dictionary<int, Tensor[]>(); | |||||
foreach (var (x, y) in zip(this.inputs, inputs)) | |||||
{ | |||||
var y1 = conform_to_reference_input(y, x); | |||||
var x_id = x.GetHashCode(); | |||||
tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1).ToArray(); | |||||
} | |||||
var depth_keys = NodesByDepth.Keys.Reverse().ToArray(); | |||||
foreach(var depth in depth_keys) | |||||
{ | |||||
var nodes = NodesByDepth[depth]; | |||||
foreach(var node in nodes) | |||||
{ | |||||
// Input tensors already exist. | |||||
if (node.IsInput) | |||||
continue; | |||||
var layer_inputs = new Tensors(tensor_dict[node.FlatInputIds[0]]); | |||||
tensor_dict[node.FlatInputIds[0]] = new Tensor[0]; | |||||
var outputs = node.Layer.Apply(layer_inputs, is_training: training); | |||||
// Update tensor_dict. | |||||
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | |||||
tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray(); | |||||
} | |||||
} | |||||
foreach(var x in outputs) | |||||
{ | |||||
} | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input) | |||||
{ | |||||
return tensor; | |||||
} | |||||
} | } | ||||
} | } |
@@ -9,10 +9,10 @@ namespace Tensorflow.Keras.Engine | |||||
/// </summary> | /// </summary> | ||||
public class KerasHistory | public class KerasHistory | ||||
{ | { | ||||
public Layer layer; | |||||
Layer layer; | |||||
int node_index; | int node_index; | ||||
int tensor_index; | int tensor_index; | ||||
public Tensor tensor; | |||||
Tensor tensor; | |||||
public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | ||||
{ | { | ||||
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||||
if (!built) | if (!built) | ||||
MaybeBuild(inputs); | MaybeBuild(inputs); | ||||
outputs = call_fn(inputs, state: state, is_training: is_training); | |||||
outputs = CallFn(inputs, state: state, is_training: is_training); | |||||
outputs = _set_connectivity_metadata_(inputs, outputs); | outputs = _set_connectivity_metadata_(inputs, outputs); | ||||
_handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine | |||||
if (!dynamic) | if (!dynamic) | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
outputs = call_fn(inputs); | |||||
outputs = CallFn(inputs); | |||||
outputs = _set_connectivity_metadata_(inputs, outputs); | outputs = _set_connectivity_metadata_(inputs, outputs); | ||||
_handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
@@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="state"></param> | /// <param name="state"></param> | ||||
/// <param name="is_training"></param> | /// <param name="is_training"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
protected virtual Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -39,20 +39,42 @@ namespace Tensorflow.Keras.Engine | |||||
public Tensors Outputs => args.Outputs; | public Tensors Outputs => args.Outputs; | ||||
public TensorShape[] input_shapes; | public TensorShape[] input_shapes; | ||||
public TensorShape[] output_shapes; | public TensorShape[] output_shapes; | ||||
List<Tensor> kerasInputs = new List<Tensor>(); | |||||
public List<Tensor> KerasInputs = new List<Tensor>(); | |||||
public Layer Layer { get; set; } | |||||
public bool IsInput => args.InputTensors == null; | |||||
public int[] FlatInputIds { get; set; } | |||||
public int[] FlatOutputIds { get; set; } | |||||
public Node[] ParentNodes | |||||
{ | |||||
get | |||||
{ | |||||
var node_deps = new List<Node>(); | |||||
foreach(var kt in KerasInputs) | |||||
{ | |||||
var (layer, node_index, _) = kt.KerasHistory; | |||||
if (layer != null) | |||||
node_deps.append(layer.InboundNodes[node_index]); | |||||
} | |||||
return node_deps.ToArray(); | |||||
} | |||||
} | |||||
public Node(Layer layer, NodeArgs args) | public Node(Layer layer, NodeArgs args) | ||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
this.Layer = layer; | |||||
if (args.InputTensors != null) | if (args.InputTensors != null) | ||||
kerasInputs.AddRange(args.InputTensors); | |||||
KerasInputs.AddRange(args.InputTensors); | |||||
// Wire up Node to Layers. | // Wire up Node to Layers. | ||||
layer.InboundNodes.Add(this); | layer.InboundNodes.Add(this); | ||||
foreach (var kt in kerasInputs) | |||||
foreach (var kt in KerasInputs) | |||||
{ | { | ||||
var inbound_layer = kt.KerasHistory.layer; | |||||
if (kt.KerasHistory == null) | |||||
continue; | |||||
var (inbound_layer, _, _) = kt.KerasHistory; | |||||
if (inbound_layer != null) | if (inbound_layer != null) | ||||
inbound_layer.OutboundNodes.Add(this); | inbound_layer.OutboundNodes.Add(this); | ||||
} | } | ||||
@@ -61,6 +83,10 @@ namespace Tensorflow.Keras.Engine | |||||
var node_index = layer.InboundNodes.Count - 1; | var node_index = layer.InboundNodes.Count - 1; | ||||
foreach (var (i, tensor) in enumerate(Outputs)) | foreach (var (i, tensor) in enumerate(Outputs)) | ||||
tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | ||||
// Cached for performance. | |||||
FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray(); | |||||
FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
return base.call_fn(inputs, state, is_training); | |||||
return base.CallFn(inputs, state, is_training); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||||
{ | { | ||||
var outputs = _convolution_op.Apply(inputs, kernel); | var outputs = _convolution_op.Apply(inputs, kernel); | ||||
if (use_bias) | if (use_bias) | ||||
@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
var rank = inputs.rank; | var rank = inputs.rank; | ||||
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
var output = tf_utils.smart_cond(is_training, | var output = tf_utils.smart_cond(is_training, | ||||
() => tf.nn.dropout(inputs, | () => tf.nn.dropout(inputs, | ||||
@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
var dtype = inputs.dtype; | var dtype = inputs.dtype; | ||||
if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||||
.ToArray(); | .ToArray(); | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
return base.call_fn(inputs, state: state, is_training: is_training); | |||||
return base.CallFn(inputs, state: state, is_training: is_training); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||||
input_spec = new InputSpec(ndim: 4); | input_spec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
int[] pool_shape; | int[] pool_shape; | ||||
int[] strides; | int[] strides; | ||||
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
scale = math_ops.cast(args.Scale, args.DType); | scale = math_ops.cast(args.Scale, args.DType); | ||||
offset = math_ops.cast(args.Offset, args.DType); | offset = math_ops.cast(args.Offset, args.DType); | ||||
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.input_spec = new InputSpec(ndim: 4); | this.input_spec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
return tf.keras.backend.spatial_2d_padding(inputs, | return tf.keras.backend.spatial_2d_padding(inputs, | ||||
padding: padding, | padding: padding, | ||||
@@ -74,7 +74,7 @@ namespace Tensorflow | |||||
/// <param name="training"></param> | /// <param name="training"></param> | ||||
/// <param name="state"></param> | /// <param name="state"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
var one = constant_op.constant(1, dtype: dtypes.int32); | var one = constant_op.constant(1, dtype: dtypes.int32); | ||||
// Parameters of gates are concatenated into one multiply for efficiency. | // Parameters of gates are concatenated into one multiply for efficiency. | ||||
@@ -67,7 +67,7 @@ namespace Tensorflow | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | { | ||||
// Most basic RNN: output = new_state = act(W * input + U * state + B). | // Most basic RNN: output = new_state = act(W * input + U * state + B). | ||||
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | ||||
@@ -145,6 +145,7 @@ namespace Tensorflow | |||||
/// Keras History: (Layer, (node_index, tensor_index)) | /// Keras History: (Layer, (node_index, tensor_index)) | ||||
/// </summary> | /// </summary> | ||||
public KerasHistory KerasHistory { get; set; } | public KerasHistory KerasHistory { get; set; } | ||||
public Tensor KerasMask { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||