Browse Source

Fix _compute_fans when the shape has more than 4 dimensions.

tags/v0.8.0
haiping008 6 years ago
parent
commit
9f0c1e5fe2
7 changed files with 72 additions and 8 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  2. +5
    -1
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  3. +21
    -4
      src/TensorFlowNET.Core/Layers/Layer.cs
  4. +11
    -1
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  5. +16
    -1
      src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs
  6. +1
    -0
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  7. +2
    -1
      src/TensorFlowNET.Core/Variables/_VariableStore.cs

+ 16
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -57,5 +57,21 @@ namespace Tensorflow.Keras.Engine
{

}

/// <summary>
/// Adds a weight to the layer by forwarding every argument to
/// <c>_add_variable_with_custom_getter</c> with <c>overwrite: true</c>.
/// </summary>
/// <param name="name">Name forwarded for the variable.</param>
/// <param name="shape">Shape forwarded for the variable.</param>
/// <param name="dtype">Data type forwarded unchanged.</param>
/// <param name="initializer">Initializer forwarded unchanged; may be null.</param>
/// <param name="trainable">
/// Whether the variable is trainable. When null (the default) it is
/// treated as <c>true</c>.
/// </param>
/// <param name="getter">Callback forwarded as the variable getter.</param>
protected virtual void add_weight(string name,
    int[] shape,
    TF_DataType dtype = TF_DataType.DtInvalid,
    IInitializer initializer = null,
    bool? trainable = null,
    Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
{
    // `trainable` is nullable and defaults to null, so reading `.Value`
    // directly would throw InvalidOperationException for callers that rely
    // on the default. Fall back to true, the Keras default for weights.
    bool is_trainable = trainable ?? true;

    _add_variable_with_custom_getter(name,
        shape,
        dtype: dtype,
        getter: getter,
        overwrite: true,
        initializer: initializer,
        trainable: is_trainable);
}
}
}

+ 5
- 1
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -53,7 +53,11 @@ namespace Tensorflow.Keras.Layers
int channel_axis = data_format == "channels_first" ? 1 : -1;
int input_dim = input_shape.Dimensions[input_shape.NDim - 1];
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
add_weight();
add_weight(name: "kernel",
shape: kernel_shape,
initializer: kernel_initializer,
trainable: true,
dtype: _dtype);
}
}
}

+ 21
- 4
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -68,7 +68,11 @@ namespace Tensorflow.Layers
throw new NotImplementedException("");
}

protected virtual void add_weight()
protected virtual void add_weight(string name,
int[] shape,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
bool? trainable = null)
{
var default_graph = ops.get_default_graph();
Graph init_graph = null;
@@ -84,7 +88,9 @@ namespace Tensorflow.Layers
existing_variables = variables.global_variables().ToArray();
}

var dtype = TF_DataType.TF_FLOAT;
if(dtype == TF_DataType.DtInvalid)
dtype = TF_DataType.TF_FLOAT;

_set_scope();
var reuse = built || (_reuse != null && _reuse.Value);
Python.with(tf.variable_scope(_scope,
@@ -94,8 +100,19 @@ namespace Tensorflow.Layers
_current_scope = scope;
Python.with(ops.name_scope(_name_scope()), delegate
{


base.add_weight(name,
shape,
dtype: dtype,
initializer: initializer,
trainable: trainable,
getter: (name1, shape1, dtype1, initializer1, trainable1) =>
{
return tf.get_variable(name1,
shape: new TensorShape(shape1),
dtype: dtype1,
initializer: initializer1,
trainable: trainable1);
});
});
});
}


+ 11
- 1
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Operations.Initializers
@@ -64,7 +65,16 @@ namespace Tensorflow.Operations.Initializers
if (shape.Length == 2)
return (shape[0], shape[1]);
else
throw new NotImplementedException("VarianceScaling._compute_fans");
{
    // Assuming convolution kernels (2D, 3D, or more).
    // kernel shape: (..., input_depth, depth)
    // The receptive field is the product of every dimension except the
    // trailing two (input_depth, depth). Multiplying a fixed count of
    // leading dimensions is only right for rank-4 shapes: for rank 3 it
    // would wrongly include input_depth, and for rank 5+ it would drop
    // spatial dimensions.
    int receptive_field_size = 1;
    for (int i = 0; i < shape.Length - 2; i++)
        receptive_field_size *= shape[i];

    var fan_in = shape[shape.Length - 2] * receptive_field_size;
    var fan_out = shape[shape.Length - 1] * receptive_field_size;
    return (fan_in, fan_out);
}
}

public virtual object get_config()


+ 16
- 1
src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs View File

@@ -4,7 +4,22 @@ using System.Text;

namespace Tensorflow
{
public class CheckpointableBase
public abstract class CheckpointableBase
{
/// <summary>
/// Restore-on-create for a variable to be saved with this `Checkpointable`.
/// At present this only invokes <paramref name="getter"/> and then throws,
/// because the remainder of the method is not implemented yet.
/// </summary>
/// <param name="name">Name passed to <paramref name="getter"/>.</param>
/// <param name="shape">Shape passed to <paramref name="getter"/>.</param>
/// <param name="dtype">Data type passed to <paramref name="getter"/>.</param>
/// <param name="initializer">Initializer passed to <paramref name="getter"/>; may be null.</param>
/// <param name="getter">Callback that creates the variable; must not be null.</param>
/// <param name="overwrite">Currently unused by this implementation.</param>
/// <param name="trainable">Trainable flag passed to <paramref name="getter"/>.</param>
/// <returns>Nothing yet: the method always throws after calling the getter.</returns>
/// <exception cref="ArgumentNullException"><paramref name="getter"/> is null.</exception>
/// <exception cref="NotImplementedException">Always, once the getter has run.</exception>
protected virtual RefVariable _add_variable_with_custom_getter(string name,
    int[] shape,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    IInitializer initializer = null,
    Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
    bool overwrite = false,
    bool trainable = false)
{
    // `getter` defaults to null but is invoked unconditionally; fail with a
    // descriptive error instead of an opaque NullReferenceException.
    if (getter == null)
        throw new ArgumentNullException(nameof(getter));

    var new_variable = getter(name, shape, dtype, initializer, trainable);
    throw new NotImplementedException("_add_variable_with_custom_getter");
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -48,6 +48,7 @@ namespace Tensorflow
shape: shape,
dtype: dtype,
initializer: initializer,
reuse: resue,
trainable: trainable,
synchronization: synchronization,
aggregation: aggregation);


+ 2
- 1
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -24,6 +24,7 @@ namespace Tensorflow
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
object initializer = null, // IInitializer or Tensor
bool? reuse = null,
bool? trainable = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
@@ -100,7 +101,7 @@ namespace Tensorflow
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{
bool initializing_from_value = true;
bool initializing_from_value = false;
if (use_resource == null)
use_resource = false;



Loading…
Cancel
Save