Browse Source

Dese layer

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
7ebc2b2b77
3 changed files with 31 additions and 5 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
  2. +28
    -2
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs

+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
public class InputSpec
{
public int ndim;
public int? ndim;
public int? min_ndim;
Dictionary<int, int> axes;

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Engine
int? min_ndim = null,
Dictionary<int, int> axes = null)
{
this.ndim = ndim.Value;
this.ndim = ndim;
if (axes == null)
axes = new Dictionary<int, int>();
this.axes = axes;


+ 28
- 2
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations.Activation;
@@ -8,11 +9,13 @@ namespace Tensorflow.Keras.Layers
{
public class Dense : Tensorflow.Layers.Layer
{
protected int uints;
protected int units;
protected IActivation activation;
protected bool use_bias;
protected IInitializer kernel_initializer;
protected IInitializer bias_initializer;
protected RefVariable kernel;
protected RefVariable bias;

public Dense(int units,
IActivation activation,
@@ -21,7 +24,7 @@ namespace Tensorflow.Keras.Layers
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null) : base(trainable: trainable)
{
this.uints = units;
this.units = units;
this.activation = activation;
this.use_bias = use_bias;
this.kernel_initializer = kernel_initializer;
@@ -29,5 +32,28 @@ namespace Tensorflow.Keras.Layers
this.supports_masking = true;
this.input_spec = new InputSpec(min_ndim: 2);
}

protected override void build(TensorShape input_shape)
{
var last_dim = input_shape.Dimensions.Last();
var axes = new Dictionary<int, int>();
axes[-1] = last_dim;
input_spec = new InputSpec(min_ndim: 2, axes: axes);
kernel = add_weight(
"kernel",
shape: new int[] { last_dim, units },
initializer: kernel_initializer,
dtype: _dtype,
trainable: true);
if (use_bias)
bias = add_weight(
"bias",
shape: new int[] { units },
initializer: bias_initializer,
dtype: _dtype,
trainable: true);

built = true;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow.Operations.Initializers

public Tensor call(TensorShape shape, TF_DataType dtype)
{
throw new NotImplementedException("");
return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed);
}

public object get_config()


Loading…
Cancel
Save