
define gamma and beta as VariableV1

Oceania2018 committed 6 years ago · commit 717f7143a6 · tags/v0.12
5 changed files with 28 additions and 28 deletions

1. src/TensorFlowNET.Core/APIs/tf.nn.cs (+2 / -2)
2. src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs (+4 / -4)
3. src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs (+4 / -19)
4. src/TensorFlowNET.Core/Operations/nn_impl.py.cs (+3 / -3)
5. src/TensorFlowNET.Core/tensorflow.cs (+15 / -0)

src/TensorFlowNET.Core/APIs/tf.nn.cs (+2 / -2)

@@ -115,8 +115,8 @@ namespace Tensorflow
        public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);

        public Tensor[] fused_batch_norm(Tensor x,
-            RefVariable scale,
-            RefVariable offset,
+            VariableV1 scale,
+            VariableV1 offset,
            Tensor mean = null,
            Tensor variance = null,
            float epsilon = 0.001f,
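
A minimal caller sketch of the widened wrapper (not part of this commit), assuming the method is reached through the usual tf.nn entry point and that tf.placeholder, tf.float32, and scalar VariableV1 initial values behave as illustrated; only the fused_batch_norm parameter types come from the hunk above.

// Hypothetical caller: scale and offset are now accepted as VariableV1,
// so whichever variable flavor the binding creates can be passed directly.
var x = tf.placeholder(tf.float32, new TensorShape(-1, 32));
var scale = tf.VariableV1(1.0f, name: "gamma");    // assumed scalar initial value
var offset = tf.VariableV1(0.0f, name: "beta");    // assumed scalar initial value
Tensor[] result = tf.nn.fused_batch_norm(x, scale, offset, epsilon: 0.001f);
Tensor y = result[0];    // normalized output; the remaining entries are batch statistics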


src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs (+4 / -4)

@@ -37,8 +37,8 @@ namespace Tensorflow.Keras.Layers
        private IInitializer gamma_initializer;
        private IInitializer moving_mean_initializer;
        private IInitializer moving_variance_initializer;
-        private RefVariable gamma;
-        private RefVariable beta;
+        private VariableV1 gamma;
+        private VariableV1 beta;
        private RefVariable moving_mean;
        private RefVariable moving_variance;

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers
            var param_shape = new int[] { input_shape.dims[axis[0]] };

            if (scale)
-                gamma = (RefVariable)add_weight("gamma",
+                gamma = add_weight("gamma",
                    param_shape,
                    dtype: param_dtype,
                    initializer: gamma_initializer,

@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Layers
                throw new NotImplementedException("add_weight gamma");

            if (center)
-                beta = (RefVariable)add_weight("beta",
+                beta = add_weight("beta",
                    param_shape,
                    dtype: param_dtype,
                    initializer: beta_initializer,
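
A condensed sketch of the build step after the cast removal; the surrounding fields (axis, param_dtype, scale, center) and the add_weight arguments cut off in the hunks are assumptions, and the direct assignment relies on add_weight now returning a VariableV1-compatible object, as the make_variable change below suggests.

// Sketch: with gamma/beta widened to VariableV1, the (RefVariable) downcast
// of the add_weight result is no longer needed.
var param_shape = new int[] { input_shape.dims[axis[0]] };
if (scale)
    gamma = add_weight("gamma", param_shape,
        dtype: param_dtype,
        initializer: gamma_initializer);    // remaining arguments as in the original call
if (center)
    beta = add_weight("beta", param_shape,
        dtype: param_dtype,
        initializer: beta_initializer);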


src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs (+4 / -19)

@@ -32,36 +32,21 @@ namespace Tensorflow.Keras.Utils
        /// <param name="initializer"></param>
        /// <param name="trainable"></param>
        /// <returns></returns>
-        public static RefVariable make_variable(string name,
+        public static VariableV1 make_variable(string name,
            int[] shape,
            TF_DataType dtype = TF_DataType.TF_FLOAT,
            IInitializer initializer = null,
-            bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true);
-
-        /// <summary>
-        /// Adds a new variable to the layer.
-        /// </summary>
-        /// <param name="name"></param>
-        /// <param name="shape"></param>
-        /// <param name="dtype"></param>
-        /// <param name="initializer"></param>
-        /// <param name="trainable"></param>
-        /// <returns></returns>
-        public static RefVariable make_variable(string name,
-            int[] shape,
-            TF_DataType dtype = TF_DataType.TF_FLOAT,
-            IInitializer initializer = null,
-            bool trainable = true,
-            bool use_resource = true)
+            bool trainable = true)
        {
            var initializing_from_value = false;
+            bool use_resource = true;

            ops.init_scope();

            Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);

            var variable_dtype = dtype.as_base_dtype();
-            var v = tf.Variable(init_val);
+            var v = tf.VariableV1(init_val);

            return v;
        }
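
A hypothetical call against the collapsed overload; the zeros initializer is an assumption, while the signature, the hard-coded use_resource local, and the unconditional initializer.call come from the hunk above. Note that the local use_resource flag is not yet forwarded to tf.VariableV1 here, and passing the default null initializer would throw when init_val runs.

// Hypothetical usage (names illustrative): a concrete IInitializer must be
// supplied because make_variable dereferences it when building init_val.
IInitializer init = tf.zeros_initializer;    // assumption: a zeros initializer is exposed on tf
VariableV1 kernel = base_layer_utils.make_variable("kernel",
    new int[] { 3, 3 },
    dtype: TF_DataType.TF_FLOAT,
    initializer: init,
    trainable: true);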


src/TensorFlowNET.Core/Operations/nn_impl.py.cs (+3 / -3)

@@ -97,9 +97,9 @@ namespace Tensorflow
        /// <param name="is_training"></param>
        /// <param name="name"></param>
        /// <returns></returns>
-        public static Tensor[] fused_batch_norm(Tensor x,
-            RefVariable scale,
-            RefVariable offset,
+        public static Tensor[] fused_batch_norm(Tensor x,
+            VariableV1 scale,
+            VariableV1 offset,
            Tensor mean,
            Tensor variance,
            float epsilon = 0.001f,


src/TensorFlowNET.Core/tensorflow.cs (+15 / -0)

@@ -58,6 +58,21 @@ namespace Tensorflow
                dtype: dtype);
        }

+        public VariableV1 VariableV1<T>(T data,
+            bool trainable = true,
+            bool validate_shape = true,
+            string name = null,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            bool use_resource = false)
+        {
+            return Tensorflow.variable_scope.default_variable_creator(data,
+                trainable: trainable,
+                validate_shape: validate_shape,
+                name: name,
+                dtype: dtype,
+                use_resource: use_resource);
+        }
+
        public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
        {
            return gen_array_ops.placeholder(dtype, shape, name);
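
A hypothetical caller of the added overload; the deferred Func<Tensor> initial value mirrors how make_variable above uses it, while tf.constant and the concrete argument values are assumptions.

// Hypothetical usage of the new generic overload: data is forwarded to
// variable_scope.default_variable_creator, so a deferred Func<Tensor>
// initial value works the same way make_variable passes one.
Func<Tensor> init_val = () => tf.constant(0.0f);    // assumption: scalar constant helper
VariableV1 beta = tf.VariableV1(init_val,
    trainable: true,
    name: "beta",
    dtype: TF_DataType.DtInvalid,
    use_resource: true);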

