@@ -22,6 +22,7 @@ using System.ComponentModel; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using NumSharp.Utilities; | |||
using System.Runtime.CompilerServices; | |||
namespace Tensorflow | |||
{ | |||
@@ -50,7 +51,7 @@ namespace Tensorflow | |||
=> list.Add(element); | |||
/// <summary>
/// Python-style alias for <see cref="IList{T}.Add"/>: appends <paramref name="element"/>
/// to the end of <paramref name="list"/>.
/// </summary>
public static void append<T>(this IList<T> list, T element)
    => list.Add(element);
public static T[] concat<T>(this IList<T> list1, IList<T> list2) | |||
{ | |||
@@ -407,5 +408,37 @@ namespace Tensorflow | |||
return true; | |||
return false; | |||
} | |||
/// <summary>
/// Returns true when every element of <paramref name="subset"/> is contained in
/// <paramref name="src"/> (Python set.issubset semantics). An empty sequence is
/// a subset of anything.
/// </summary>
public static bool issubset<T>(this IEnumerable<T> subset, IEnumerable<T> src)
{
    // Original computed a flag but unconditionally returned true; short-circuit
    // on the first missing element instead.
    foreach (var element in subset)
    {
        if (!src.Contains(element))
            return false;
    }
    return true;
}
/// <summary>
/// Python dict.setdefault: returns the existing value for <paramref name="key"/>,
/// otherwise stores <paramref name="value"/> under the key and returns it.
/// </summary>
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
{
    // Single lookup via TryGetValue instead of ContainsKey + indexer.
    if (dic.TryGetValue(key, out var existing))
        return existing;
    dic[key] = value;
    return value;
}
/// <summary>
/// Python dict.get: returns the value stored under <paramref name="key"/>, or
/// <paramref name="value"/> when the key is absent. Never mutates the dictionary.
/// </summary>
public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
    => dic.TryGetValue(key, out var found) ? found : value;
} | |||
} |
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
_channels_first = args.DataFormat == "channels_first"; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
if (_channels_first) | |||
{ | |||
@@ -4,6 +4,7 @@ using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -21,6 +22,11 @@ namespace Tensorflow.Keras.Engine | |||
List<Layer> _input_layers; | |||
List<KerasHistory> _input_coordinates; | |||
List<KerasHistory> _output_coordinates; | |||
public string[] NetworkNodes { get; set; } | |||
public Dictionary<int, List<Node>> NodesByDepth { get; set; } | |||
public List<Layer> Layers { get; set; } | |||
Dictionary<int, int> tensor_usage_count; | |||
public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | |||
public Functional(Tensors inputs, Tensors outputs) | |||
: base(new ModelArgs | |||
@@ -33,6 +39,7 @@ namespace Tensorflow.Keras.Engine | |||
_output_layers = new List<Layer>(); | |||
_input_coordinates = new List<KerasHistory>(); | |||
_output_coordinates = new List<KerasHistory>(); | |||
tensor_usage_count = new Dictionary<int, int>(); | |||
_init_graph_network(inputs, outputs); | |||
} | |||
@@ -67,16 +74,253 @@ namespace Tensorflow.Keras.Engine | |||
_input_layers.append(layer); | |||
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | |||
} | |||
// Keep track of the network's nodes and layers. | |||
var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); | |||
NetworkNodes = nodes; | |||
NodesByDepth = nodes_by_depth; | |||
Layers = layers; | |||
ComputeTensorUsageCount(); | |||
} | |||
/// <summary>
/// Counts how many times each tensor (keyed by its hash code) is consumed by
/// downstream nodes, walking the graph from inputs toward outputs. Populates
/// <c>tensor_usage_count</c>, used later by <c>run_internal_graph</c>.
/// </summary>
void ComputeTensorUsageCount()
{
    // Tensors produced so far; inputs are available from the start.
    var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();

    // Depth keys reversed gives deepest-first order; Skip(1) drops the input
    // depth, whose tensors are already in available_tensors.
    var depth_keys = NodesByDepth.Keys.Reverse().Skip(1).ToArray();
    foreach (var depth in depth_keys)
    {
        foreach (var node in NodesByDepth[depth])
        {
            var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray();
            // Only count a node once every tensor it consumes is available.
            if (input_tensors.issubset(available_tensors))
            {
                foreach (var tensor in node.KerasInputs)
                {
                    if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
                        tensor_usage_count[tensor.GetHashCode()] = 0;
                    tensor_usage_count[tensor.GetHashCode()] += 1;
                }

                foreach (var output_tensor in node.Outputs)
                    available_tensors.Add(output_tensor.GetHashCode());
            }
        }
    }

    // Model outputs are consumed once more by the caller.
    foreach (var tensor in outputs)
    {
        if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
            tensor_usage_count[tensor.GetHashCode()] = 0;
        tensor_usage_count[tensor.GetHashCode()] += 1;
    }
}
/// <summary> | |||
/// Validates a network's topology and gather its layers and nodes. | |||
/// </summary> | |||
/// <param name="inputs"></param> | |||
/// <param name="outputs"></param> | |||
/// <summary>
/// Validates a network's topology and gathers its layers and nodes.
/// Returns (node keys, nodes by depth, layers ordered by depth, layers by depth).
/// </summary>
/// <param name="inputs">Model input tensors.</param>
/// <param name="outputs">Model output tensors.</param>
(string[], Dictionary<int, List<Node>>, List<Layer>, Dictionary<int, List<Layer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
{
    var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);

    // Build as a List<string>: the original materialized a fixed-size string[]
    // and then called the add() extension (IList<T>.Add) on it, which throws
    // NotSupportedException for arrays. Convert to an array only at return.
    var network_nodes = nodes_in_decreasing_depth
        .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
        .ToList();

    var nodes_depths = new Dictionary<Node, int>();
    var layers_depths = new Dictionary<Layer, int>();

    // Walk in increasing depth order (inputs first).
    nodes_in_decreasing_depth.Reverse();
    foreach (var node in nodes_in_decreasing_depth)
    {
        // If the depth is not set, the node has no outbound nodes (depth 0).
        int depth = nodes_depths.SetDefault(node, 0);

        // If we've seen this layer before at a higher depth, use that depth
        // instead — necessary for shared layers that have inputs at different
        // depth levels in the graph.
        int previous_depth = layers_depths.Get(node.Layer, 0);
        depth = Math.Max(depth, previous_depth);
        layers_depths[node.Layer] = depth;
        nodes_depths[node] = depth;

        // The depth of a node is the max of the depths of all nodes it is
        // connected to, plus one.
        foreach (var node_dep in node.ParentNodes)
        {
            previous_depth = nodes_depths.Get(node_dep, 0);
            nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
        }
    }

    // Handle inputs that are not connected to outputs. We do not error out
    // here because the inputs may be used to compute losses and metrics.
    foreach (var input_t in inputs)
    {
        var (input_layer, _, _) = input_t.KerasHistory;
        if (!layers_depths.ContainsKey(input_layer))
        {
            layers_depths[input_layer] = 0;
            layer_indices[input_layer] = -1;
            nodes_depths[input_layer.InboundNodes[0]] = 0;
            network_nodes.Add(MakeNodeKey(input_layer.Name, 0));
        }
    }

    // Build a dict {depth: list of nodes with this depth}.
    var nodes_by_depth = new Dictionary<int, List<Node>>();
    foreach (var entry in nodes_depths)
    {
        if (!nodes_by_depth.ContainsKey(entry.Value))
            nodes_by_depth[entry.Value] = new List<Node>();
        nodes_by_depth[entry.Value].Add(entry.Key);
    }

    var layers_by_depth = new Dictionary<int, List<Layer>>();
    foreach (var entry in layers_depths)
    {
        if (!layers_by_depth.ContainsKey(entry.Value))
            layers_by_depth[entry.Value] = new List<Layer>();
        layers_by_depth[entry.Value].Add(entry.Key);
    }

    // Set layers ordered by depth (descending). NOTE(review): this relies on
    // Dictionary preserving insertion order for Keys — confirm acceptable.
    var depth_keys = layers_by_depth.Keys.Reverse();
    var layers = new List<Layer>();
    foreach (var depth in depth_keys)
    {
        var layers_for_depth = layers_by_depth[depth];
        // Network.layers needs a deterministic order: order by traversal order.
        layers_for_depth.Reverse();
        layers.AddRange(layers_for_depth);
    }

    return (network_nodes.ToArray(), nodes_by_depth, layers, layers_by_depth);
}
// Stable string key identifying the node_index-th inbound node of a layer.
string MakeNodeKey(string layer_name, int node_index)
    => $"{layer_name}_ib-{node_index}";
/// <summary> | |||
/// This method topologically sorts nodes in order from inputs to outputs. | |||
/// </summary> | |||
/// <param name="outputs"></param> | |||
/// <summary>
/// Topologically sorts the graph's nodes from inputs to outputs, starting a
/// depth-first walk from each output tensor.
/// </summary>
/// <param name="outputs">Output tensors whose histories seed the traversal.</param>
(List<Node>, Dictionary<Layer, int>) BuildMap(Tensors outputs)
{
    var finished = new List<Node>();
    var in_progress = new List<Node>();
    var decreasing_depth = new List<Node>();
    var layer_order = new Dictionary<Layer, int>();

    foreach (var output_tensor in outputs)
    {
        BuildMapHelper(output_tensor,
            finished,
            in_progress,
            decreasing_depth,
            layer_order);
    }

    return (decreasing_depth, layer_order);
}
/// <summary>
/// Recursive DFS step for <c>BuildMap</c>: visits the node that produced
/// <paramref name="tensor"/>, recursing into its inputs first, and appends the
/// node to <paramref name="nodes_in_decreasing_depth"/> post-order.
/// </summary>
/// <exception cref="ValueError">Thrown when a cycle is detected.</exception>
void BuildMapHelper(Tensor tensor,
    List<Node> finished_nodes,
    List<Node> nodes_in_progress,
    List<Node> nodes_in_decreasing_depth,
    Dictionary<Layer, int> layer_indices)
{
    var (layer, node_index, _) = tensor.KerasHistory;
    var node = layer.InboundNodes[node_index];

    // Don't repeat work for shared subgraphs.
    if (finished_nodes.Contains(node))
        return;

    // Prevent cycles.
    if (nodes_in_progress.Contains(node))
        throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");

    // Store the traversal order for layer sorting.
    if (!layer_indices.ContainsKey(layer))
        layer_indices[layer] = layer_indices.Count;

    nodes_in_progress.Add(node);

    // Propagate to all previous tensors connected to this node.
    foreach (var k_tensor in node.KerasInputs)
        BuildMapHelper(k_tensor,
            finished_nodes,
            nodes_in_progress,
            nodes_in_decreasing_depth,
            layer_indices);

    finished_nodes.Add(node);
    nodes_in_progress.Remove(node);
    // Insert(list.Count, node) in the original is simply Add.
    nodes_in_decreasing_depth.Add(node);
}
// Forward pass of a Functional model: executes the node graph over the inputs.
// NOTE(review): the `state` parameter is ignored here — run_internal_graph
// receives only the inputs and the training flag; confirm that is intended.
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
{
    return run_internal_graph(inputs, is_training);
}
// Executes the model graph depth by depth, feeding each node's layer and
// caching produced tensors in tensor_dict (keyed by tensor hash code).
// WARNING: implementation is unfinished — it always throws
// NotImplementedException after the final (empty) output loop.
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
{
    if (mask != null)
    {
        // NOTE(review): `masks` is freshly allocated, so every element is null;
        // this assigns null KerasMask to each input regardless of `mask`.
        // Looks like `mask` was meant to be copied into `masks` — confirm.
        Tensor[] masks = new Tensor[inputs.Count()];
        foreach (var (i, input_t) in enumerate(inputs))
            input_t.KerasMask = masks[i];
    }

    // Seed the cache: each model input is replicated once per expected use
    // (tensor_usage_count), so consumers can pop their copy.
    var tensor_dict = new Dictionary<int, Tensor[]>();
    foreach (var (x, y) in zip(this.inputs, inputs))
    {
        var y1 = conform_to_reference_input(y, x);
        var x_id = x.GetHashCode();
        tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1).ToArray();
    }

    // Deepest nodes first (inputs), toward the outputs.
    var depth_keys = NodesByDepth.Keys.Reverse().ToArray();
    foreach(var depth in depth_keys)
    {
        var nodes = NodesByDepth[depth];
        foreach(var node in nodes)
        {
            // Input tensors already exist.
            if (node.IsInput)
                continue;

            // NOTE(review): only FlatInputIds[0] is consumed — multi-input
            // nodes would drop their remaining inputs; confirm single-input
            // assumption or extend to all FlatInputIds.
            var layer_inputs = new Tensors(tensor_dict[node.FlatInputIds[0]]);
            tensor_dict[node.FlatInputIds[0]] = new Tensor[0];
            var outputs = node.Layer.Apply(layer_inputs, is_training: training);

            // Update tensor_dict.
            foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
                tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray();
        }
    }

    // TODO: gather the model's output tensors from tensor_dict and return them.
    foreach(var x in outputs)
    {
    }

    throw new NotImplementedException("");
}
// Placeholder: should adapt `tensor` to match the reference input `ref_input`
// (presumably rank/dtype coercion as in Keras — TODO confirm). Currently a
// no-op pass-through.
Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)
{
    return tensor;
}
} | |||
} |
@@ -9,10 +9,10 @@ namespace Tensorflow.Keras.Engine | |||
/// </summary> | |||
public class KerasHistory | |||
{ | |||
public Layer layer; | |||
Layer layer; | |||
int node_index; | |||
int tensor_index; | |||
public Tensor tensor; | |||
Tensor tensor; | |||
public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | |||
{ | |||
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||
if (!built) | |||
MaybeBuild(inputs); | |||
outputs = call_fn(inputs, state: state, is_training: is_training); | |||
outputs = CallFn(inputs, state: state, is_training: is_training); | |||
outputs = _set_connectivity_metadata_(inputs, outputs); | |||
_handle_activity_regularization(inputs, outputs); | |||
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine | |||
if (!dynamic) | |||
throw new NotImplementedException(""); | |||
outputs = call_fn(inputs); | |||
outputs = CallFn(inputs); | |||
outputs = _set_connectivity_metadata_(inputs, outputs); | |||
_handle_activity_regularization(inputs, outputs); | |||
@@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine | |||
/// <param name="state"></param> | |||
/// <param name="is_training"></param> | |||
/// <returns></returns> | |||
// Default forward pass; concrete layers must override to define their
// computation. The base implementation is intentionally unimplemented.
protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
{
    throw new NotImplementedException("");
}
@@ -39,20 +39,42 @@ namespace Tensorflow.Keras.Engine | |||
public Tensors Outputs => args.Outputs; | |||
public TensorShape[] input_shapes; | |||
public TensorShape[] output_shapes; | |||
List<Tensor> kerasInputs = new List<Tensor>(); | |||
public List<Tensor> KerasInputs = new List<Tensor>(); | |||
public Layer Layer { get; set; } | |||
public bool IsInput => args.InputTensors == null; | |||
public int[] FlatInputIds { get; set; } | |||
public int[] FlatOutputIds { get; set; } | |||
/// <summary>
/// The nodes that produced this node's input tensors, resolved through each
/// input tensor's KerasHistory. Computed on every access.
/// </summary>
public Node[] ParentNodes
{
    get
    {
        var parents = new List<Node>();
        foreach (var input_tensor in KerasInputs)
        {
            var (origin_layer, origin_node_index, _) = input_tensor.KerasHistory;
            if (origin_layer != null)
                parents.Add(origin_layer.InboundNodes[origin_node_index]);
        }
        return parents.ToArray();
    }
}
public Node(Layer layer, NodeArgs args) | |||
{ | |||
this.args = args; | |||
this.Layer = layer; | |||
if (args.InputTensors != null) | |||
kerasInputs.AddRange(args.InputTensors); | |||
KerasInputs.AddRange(args.InputTensors); | |||
// Wire up Node to Layers. | |||
layer.InboundNodes.Add(this); | |||
foreach (var kt in kerasInputs) | |||
foreach (var kt in KerasInputs) | |||
{ | |||
var inbound_layer = kt.KerasHistory.layer; | |||
if (kt.KerasHistory == null) | |||
continue; | |||
var (inbound_layer, _, _) = kt.KerasHistory; | |||
if (inbound_layer != null) | |||
inbound_layer.OutboundNodes.Add(this); | |||
} | |||
@@ -61,6 +83,10 @@ namespace Tensorflow.Keras.Engine | |||
var node_index = layer.InboundNodes.Count - 1; | |||
foreach (var (i, tensor) in enumerate(Outputs)) | |||
tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | |||
// Cached for performance. | |||
FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray(); | |||
FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray(); | |||
} | |||
} | |||
} |
@@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine | |||
built = true; | |||
} | |||
// Forward pass: delegates to the base implementation unchanged.
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
{
    return base.CallFn(inputs, state, is_training);
}
} | |||
} |
@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
Tensor outputs = null; | |||
@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||
{ | |||
var outputs = _convolution_op.Apply(inputs, kernel); | |||
if (use_bias) | |||
@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||
{ | |||
Tensor outputs = null; | |||
var rank = inputs.rank; | |||
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||
this.args = args; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
var output = tf_utils.smart_cond(is_training, | |||
() => tf.nn.dropout(inputs, | |||
@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
var dtype = inputs.dtype; | |||
if (dtype != tf.int32 && dtype != tf.int64) | |||
@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||
.ToArray(); | |||
} | |||
// Forward pass: delegates to the base implementation unchanged.
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
{
    return base.CallFn(inputs, state: state, is_training: is_training);
}
} | |||
} |
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||
input_spec = new InputSpec(ndim: 4); | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
int[] pool_shape; | |||
int[] strides; | |||
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
this.args = args; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
scale = math_ops.cast(args.Scale, args.DType); | |||
offset = math_ops.cast(args.Offset, args.DType); | |||
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers | |||
this.input_spec = new InputSpec(ndim: 4); | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
return tf.keras.backend.spatial_2d_padding(inputs, | |||
padding: padding, | |||
@@ -74,7 +74,7 @@ namespace Tensorflow | |||
/// <param name="training"></param> | |||
/// <param name="state"></param> | |||
/// <returns></returns> | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
var one = constant_op.constant(1, dtype: dtypes.int32); | |||
// Parameters of gates are concatenated into one multiply for efficiency. | |||
@@ -67,7 +67,7 @@ namespace Tensorflow | |||
built = true; | |||
} | |||
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
// Most basic RNN: output = new_state = act(W * input + U * state + B). | |||
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | |||
@@ -145,6 +145,7 @@ namespace Tensorflow | |||
/// Keras History: (Layer, (node_index, tensor_index)) | |||
/// </summary> | |||
public KerasHistory KerasHistory { get; set; } | |||
public Tensor KerasMask { get; set; } | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||