From 9f0c1e5fe24c1dba077b7a622f38aa007deb1ac8 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 8 Mar 2019 17:41:14 -0600 Subject: [PATCH] fix _compute_fans when more than 4 dimensions. --- src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 16 ++++++++++++ src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 6 ++++- src/TensorFlowNET.Core/Layers/Layer.cs | 25 ++++++++++++++++--- .../Initializers/VarianceScaling.cs | 12 ++++++++- .../Checkpointable/CheckpointableBase.cs | 17 ++++++++++++- .../Variables/VariableScope.cs | 1 + .../Variables/_VariableStore.cs | 3 ++- 7 files changed, 72 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 941d6904..3c5825e2 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -57,5 +57,21 @@ namespace Tensorflow.Keras.Engine { } + + protected virtual void add_weight(string name, + int[] shape, + TF_DataType dtype = TF_DataType.DtInvalid, + IInitializer initializer = null, + bool? trainable = null, + Func getter = null) + { + _add_variable_with_custom_getter(name, + shape, + dtype: dtype, + getter: getter, + overwrite: true, + initializer: initializer, + trainable: trainable.Value); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 9d674165..fdd4329a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -53,7 +53,11 @@ namespace Tensorflow.Keras.Layers int channel_axis = data_format == "channels_first" ? 1 : -1; int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; - add_weight(); + add_weight(name: "kernel", + shape: kernel_shape, + initializer: kernel_initializer, + trainable: true, + dtype: _dtype); } } } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 128d0af1..b8e618f1 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -68,7 +68,11 @@ namespace Tensorflow.Layers throw new NotImplementedException(""); } - protected virtual void add_weight() + protected virtual void add_weight(string name, + int[] shape, + TF_DataType dtype = TF_DataType.DtInvalid, + IInitializer initializer = null, + bool? trainable = null) { var default_graph = ops.get_default_graph(); Graph init_graph = null; @@ -84,7 +88,9 @@ namespace Tensorflow.Layers existing_variables = variables.global_variables().ToArray(); } - var dtype = TF_DataType.TF_FLOAT; + if(dtype == TF_DataType.DtInvalid) + dtype = TF_DataType.TF_FLOAT; + _set_scope(); var reuse = built || (_reuse != null && _reuse.Value); Python.with(tf.variable_scope(_scope, @@ -94,8 +100,19 @@ namespace Tensorflow.Layers _current_scope = scope; Python.with(ops.name_scope(_name_scope()), delegate { - - + base.add_weight(name, + shape, + dtype: dtype, + initializer: initializer, + trainable: trainable, + getter: (name1, shape1, dtype1, initializer1, trainable1) => + { + return tf.get_variable(name1, + shape: new TensorShape(shape1), + dtype: dtype1, + initializer: initializer1, + trainable: trainable1); + }); }); }); } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index 0fcaf392..7a8d9af8 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow.Operations.Initializers @@ -64,7 +65,16 @@ namespace Tensorflow.Operations.Initializers if (shape.Length == 2) return (shape[0], shape[1]); else - throw new NotImplementedException("VarianceScaling._compute_fans"); + { + // Assuming convolution kernels (2D, 3D, or more). + // kernel shape: (..., input_depth, depth) + int receptive_field_size = 1; + foreach (var dim in shape.Take(2)) + receptive_field_size *= dim; + var fan_in = shape[shape.Length - 2] * receptive_field_size; + var fan_out = shape[shape.Length - 1] * receptive_field_size; + return (fan_in, fan_out); + } } public virtual object get_config() diff --git a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs index 632278fe..558a7177 100644 --- a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs +++ b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs @@ -4,7 +4,22 @@ using System.Text; namespace Tensorflow { - public class CheckpointableBase + public abstract class CheckpointableBase { + /// + /// Restore-on-create for a variable be saved with this `Checkpointable`. + /// + /// + protected virtual RefVariable _add_variable_with_custom_getter(string name, + int[] shape, + TF_DataType dtype = TF_DataType.TF_FLOAT, + IInitializer initializer = null, + Func getter = null, + bool overwrite = false, + bool trainable = false) + { + var new_variable = getter(name, shape, dtype, initializer, trainable); + throw new NotImplementedException("_add_variable_with_custom_getter"); + } } } diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 33e4bcfc..7a2ef841 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -48,6 +48,7 @@ namespace Tensorflow shape: shape, dtype: dtype, initializer: initializer, + reuse: resue, trainable: trainable, synchronization: synchronization, aggregation: aggregation); diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index fd00ee0c..33365cbd 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -24,6 +24,7 @@ namespace Tensorflow TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, object initializer = null, // IInitializer or Tensor + bool? reuse = null, bool? trainable = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.AUTO, @@ -100,7 +101,7 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { - bool initializing_from_value = true; + bool initializing_from_value = false; if (use_resource == null) use_resource = false;