@@ -13,10 +13,10 @@ namespace Tensorflow.Keras | |||||
List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false); | Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false); | ||||
List<IVariableV1> trainable_variables { get; } | |||||
List<IVariableV1> trainable_weights { get; } | |||||
List<IVariableV1> non_trainable_weights { get; } | |||||
Shape output_shape { get; } | |||||
List<IVariableV1> TrainableVariables { get; } | |||||
List<IVariableV1> TrainableWeights { get; } | |||||
List<IVariableV1> NonTrainableWeights { get; } | |||||
Shape OutputShape { get; } | |||||
Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
int count_params(); | int count_params(); | ||||
@@ -67,11 +67,11 @@ namespace Tensorflow | |||||
public bool Trainable => throw new NotImplementedException(); | public bool Trainable => throw new NotImplementedException(); | ||||
public List<IVariableV1> trainable_variables => throw new NotImplementedException(); | |||||
public List<IVariableV1> trainable_weights => throw new NotImplementedException(); | |||||
public List<IVariableV1> non_trainable_weights => throw new NotImplementedException(); | |||||
public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | |||||
public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | |||||
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | |||||
public Shape output_shape => throw new NotImplementedException(); | |||||
public Shape OutputShape => throw new NotImplementedException(); | |||||
public Shape BatchInputShape => throw new NotImplementedException(); | public Shape BatchInputShape => throw new NotImplementedException(); | ||||
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine | |||||
}; | }; | ||||
var node_conversion_map = new Dictionary<string, int>(); | var node_conversion_map = new Dictionary<string, int>(); | ||||
foreach (var layer in _layers) | |||||
foreach (var layer in _self_tracked_trackables) | |||||
{ | { | ||||
var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; | var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; | ||||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | ||||
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
var layer_configs = new List<LayerConfig>(); | var layer_configs = new List<LayerConfig>(); | ||||
foreach (var layer in _layers) | |||||
foreach (var layer in _self_tracked_trackables) | |||||
{ | { | ||||
var filtered_inbound_nodes = new List<NodeConfig>(); | var filtered_inbound_nodes = new List<NodeConfig>(); | ||||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | ||||
@@ -65,13 +65,8 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
// Keep track of the network's nodes and layers. | // Keep track of the network's nodes and layers. | ||||
var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); | |||||
(NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs); | |||||
NetworkNodes = nodes; | |||||
NodesByDepth = nodes_by_depth; | |||||
if (_layers.Count == 0) | |||||
_layers = layers; | |||||
_self_tracked_trackables = layers; | |||||
// Build self.input_names and self.output_names. | // Build self.input_names and self.output_names. | ||||
_set_output_names(); | _set_output_names(); | ||||
@@ -5,8 +5,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
public partial class Layer | public partial class Layer | ||||
{ | { | ||||
protected List<ILayer> _layers = new List<ILayer>(); | |||||
public virtual List<ILayer> Layers => _layers; | |||||
public virtual List<ILayer> Layers => _self_tracked_trackables; | |||||
protected void StackLayers(params ILayer[] layers) | protected void StackLayers(params ILayer[] layers) | ||||
{ | { | ||||
@@ -63,7 +63,7 @@ namespace Tensorflow.Keras.Engine | |||||
public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
protected List<IVariableV1> _trainable_weights; | protected List<IVariableV1> _trainable_weights; | ||||
public virtual List<IVariableV1> trainable_variables => _trainable_weights; | |||||
public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||||
protected List<IVariableV1> _non_trainable_weights; | protected List<IVariableV1> _non_trainable_weights; | ||||
public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | ||||
@@ -88,7 +88,7 @@ namespace Tensorflow.Keras.Engine | |||||
public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
public Tensor[] input => inboundNodes[0].input_tensors; | public Tensor[] input => inboundNodes[0].input_tensors; | ||||
public Dictionary<int, List<INode>> NodesByDepth { get; set; } | public Dictionary<int, List<INode>> NodesByDepth { get; set; } | ||||
public Shape output_shape => inboundNodes[0].Outputs.shape; | |||||
public Shape OutputShape => inboundNodes[0].Outputs.shape; | |||||
protected List<ILayer> _self_tracked_trackables; | protected List<ILayer> _self_tracked_trackables; | ||||
public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
@@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine | |||||
return layer_utils.count_params(this, weights); | return layer_utils.count_params(this, weights); | ||||
return 0; | return 0; | ||||
} | } | ||||
List<IVariableV1> ILayer.trainable_weights | |||||
List<IVariableV1> ILayer.TrainableWeights | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
@@ -258,7 +258,7 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
} | } | ||||
List<IVariableV1> ILayer.non_trainable_weights | |||||
List<IVariableV1> ILayer.NonTrainableWeights | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Engine | |||||
// self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | ||||
// The _minimize call does a few extra steps unnecessary in most cases, | // The _minimize call does a few extra steps unnecessary in most cases, | ||||
// such as loss scaling and gradient clipping. | // such as loss scaling and gradient clipping. | ||||
_minimize(tape, optimizer, loss, trainable_variables); | |||||
_minimize(tape, optimizer, loss, TrainableVariables); | |||||
compiled_metrics.update_state(y, y_pred); | compiled_metrics.update_state(y, y_pred); | ||||
return metrics.Select(x => (x.Name, x.result())).ToList(); | return metrics.Select(x => (x.Name, x.result())).ToList(); | ||||
@@ -74,7 +74,7 @@ namespace Tensorflow.Keras.Engine | |||||
public override List<ILayer> Layers | public override List<ILayer> Layers | ||||
=> _flatten_layers(recursive: false, include_self: false).ToList(); | => _flatten_layers(recursive: false, include_self: false).ToList(); | ||||
public override List<IVariableV1> trainable_variables | |||||
public override List<IVariableV1> TrainableVariables | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
@@ -88,13 +88,13 @@ namespace Tensorflow.Keras.Engine | |||||
foreach (var trackable_obj in _self_tracked_trackables) | foreach (var trackable_obj in _self_tracked_trackables) | ||||
{ | { | ||||
if (trackable_obj.Trainable) | if (trackable_obj.Trainable) | ||||
variables.AddRange(trackable_obj.trainable_variables); | |||||
variables.AddRange(trackable_obj.TrainableVariables); | |||||
} | } | ||||
foreach (var layer in _layers) | |||||
foreach (var layer in _self_tracked_trackables) | |||||
{ | { | ||||
if (layer.Trainable) | if (layer.Trainable) | ||||
variables.AddRange(layer.trainable_variables); | |||||
variables.AddRange(layer.TrainableVariables); | |||||
} | } | ||||
// variables.AddRange(_trainable_weights); | // variables.AddRange(_trainable_weights); | ||||
@@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
built = false; | built = false; | ||||
var set_inputs = false; | var set_inputs = false; | ||||
if (_layers.Count == 0) | |||||
if (_self_tracked_trackables.Count == 0) | |||||
{ | { | ||||
if (layer is InputLayer) | if (layer is InputLayer) | ||||
{ | { | ||||
@@ -128,7 +128,7 @@ namespace Tensorflow.Keras.Engine | |||||
void _handle_deferred_layer_dependencies(params ILayer[] layers) | void _handle_deferred_layer_dependencies(params ILayer[] layers) | ||||
{ | { | ||||
_layers.AddRange(layers); | |||||
_self_tracked_trackables.AddRange(layers); | |||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
@@ -156,12 +156,12 @@ namespace Tensorflow.Keras.Engine | |||||
ops.init_scope(); | ops.init_scope(); | ||||
var inputs = keras.Input(batch_input_shape: input_shape, | var inputs = keras.Input(batch_input_shape: input_shape, | ||||
dtype: input_dtype, | dtype: input_dtype, | ||||
name: $"{_layers[0].Name}_input"); | |||||
name: $"{_self_tracked_trackables[0].Name}_input"); | |||||
Tensors layer_input = inputs; | Tensors layer_input = inputs; | ||||
Tensors layer_output = null; | Tensors layer_output = null; | ||||
Tensors outputs = null; | Tensors outputs = null; | ||||
List<INode> created_nodes = new List<INode>(); | List<INode> created_nodes = new List<INode>(); | ||||
foreach (var layer in _layers) | |||||
foreach (var layer in _self_tracked_trackables) | |||||
{ | { | ||||
clear_previously_created_nodes(layer, _created_nodes); | clear_previously_created_nodes(layer, _created_nodes); | ||||
layer_output = layer.Apply(layer_input); | layer_output = layer.Apply(layer_input); | ||||
@@ -338,8 +338,8 @@ namespace Tensorflow.Keras.Saving | |||||
public static List<IVariableV1> _legacy_weights(ILayer layer) | public static List<IVariableV1> _legacy_weights(ILayer layer) | ||||
{ | { | ||||
var weights = layer.trainable_weights.Select(x => x).ToList(); | |||||
weights.AddRange(layer.non_trainable_weights); | |||||
var weights = layer.TrainableWeights.Select(x => x).ToList(); | |||||
weights.AddRange(layer.NonTrainableWeights); | |||||
return weights; | return weights; | ||||
} | } | ||||
} | } | ||||
@@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Utils | |||||
print(string.Join("", range(line_length).Select(x => "_"))); | print(string.Join("", range(line_length).Select(x => "_"))); | ||||
} | } | ||||
var trainable_count = count_params(model, model.trainable_variables); | |||||
var trainable_count = count_params(model, model.TrainableVariables); | |||||
var non_trainable_count = count_params(model, model.non_trainable_variables); | var non_trainable_count = count_params(model, model.non_trainable_variables); | ||||
print($"Total params: {trainable_count + non_trainable_count}"); | print($"Total params: {trainable_count + non_trainable_count}"); | ||||
@@ -137,7 +137,7 @@ namespace Tensorflow.Keras.Utils | |||||
var fields = new string[] | var fields = new string[] | ||||
{ | { | ||||
$"{name} ({layer.GetType().Name})", | $"{name} ({layer.GetType().Name})", | ||||
$"{layer.output_shape}", | |||||
$"{layer.OutputShape}", | |||||
$"{layer.count_params()}" | $"{layer.count_params()}" | ||||
}; | }; | ||||
@@ -164,7 +164,7 @@ namespace Tensorflow.Keras.Utils | |||||
var fields = new string[] | var fields = new string[] | ||||
{ | { | ||||
$"{name}({layer.GetType().Name})", | $"{name}({layer.GetType().Name})", | ||||
$"{layer.output_shape}", | |||||
$"{layer.OutputShape}", | |||||
$"{layer.count_params()}", | $"{layer.count_params()}", | ||||
first_connection | first_connection | ||||
}; | }; | ||||