Browse Source

gen_random_ops.truncated_normal

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
958fdb8791
9 changed files with 84 additions and 6 deletions
  1. +11
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/Initializers.cs
  3. +14
    -0
      src/TensorFlowNET.Core/Keras/backend.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Layers/Layer.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  6. +23
    -0
      src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs
  7. +21
    -0
      src/TensorFlowNET.Core/Operations/random_ops.py.cs
  8. +10
    -1
      src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Variables/RefVariable.cs

+ 11
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -19,6 +19,13 @@ namespace Tensorflow.Keras.Engine
/// </summary>
protected bool built;

protected List<RefVariable> _trainable_weights;

public Layer()
{
_trainable_weights = new List<RefVariable>();
}

public Tensor __call__(Tensor inputs,
VariableScope scope = null)
{
@@ -36,6 +43,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
{
_maybe_build(inputs);
built = true;
}
});

@@ -65,13 +73,15 @@ namespace Tensorflow.Keras.Engine
bool? trainable = null,
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
{
_add_variable_with_custom_getter(name,
var variable = _add_variable_with_custom_getter(name,
shape,
dtype: dtype,
getter: getter,
overwrite: true,
initializer: initializer,
trainable: trainable.Value);
backend.track_variable(variable);
_trainable_weights.Add(variable);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Initializers.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow.Keras
/// <returns></returns>
public IInitializer he_normal(int? seed = null)
{
return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
return new VarianceScaling(scale: 2.0f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/backend.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public class backend
{
public static void track_variable(RefVariable v)
{

}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow.Layers
public Layer(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool? _reuse = null)
bool? _reuse = null) : base()
{
this.trainable = trainable;
this.stateful = false;


+ 2
- 1
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -43,7 +43,8 @@ namespace Tensorflow.Operations.Initializers

if (_distribution == "normal" || _distribution == "truncated_normal")
{
throw new NotImplementedException("truncated_normal");
float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f;
return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed);
}
else if (_distribution == "untruncated_normal")
{


+ 23
- 0
src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs View File

@@ -53,5 +53,28 @@ namespace Tensorflow

return _op.outputs[0];
}

/// <summary>
/// Outputs random values from a truncated normal distribution.
/// </summary>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <param name="seed"></param>
/// <param name="seed2"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
{
if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

var _op = _op_def_lib._apply_op_helper("TruncatedNormal",
name: name,
args: new { shape, dtype, seed, seed2 });

return _op.outputs[0];
}
}
}

+ 21
- 0
src/TensorFlowNET.Core/Operations/random_ops.py.cs View File

@@ -64,6 +64,27 @@ namespace Tensorflow
});
}

public static Tensor truncated_normal(int[] shape,
float mean = 0.0f,
float stddev = 1.0f,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null)
{
return with(ops.name_scope(name, "truncated_normal", new { shape, mean, stddev }), scope =>
{
name = scope;
var shape_tensor = _ShapeTensor(shape);
var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean");
var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev");
var (seed1, seed2) = random_seed.get_seed(seed);
var rnd = gen_random_ops.truncated_normal(shape_tensor, dtype, seed: seed1, seed2: seed2);
var mul = rnd * stddev_tensor;
var value = math_ops.add(mul, mean_tensor, name: name);
return value;
});
}

private static Tensor _ShapeTensor(int[] shape)
{
return ops.convert_to_tensor(shape, name: "shape");


+ 10
- 1
src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs View File

@@ -19,7 +19,16 @@ namespace Tensorflow
bool trainable = false)
{
var new_variable = getter(name, shape, dtype, initializer, trainable);
throw new NotImplementedException("_add_variable_with_custom_getter");
if (!overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: name,
overwrite: overwrite);
else
return new_variable;
}

protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
{
return checkpointable;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -136,8 +136,8 @@ namespace Tensorflow
{
_initial_value = (initial_value as Func<Tensor>)();
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
});
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
}
// Or get the initial value from a Tensor or Python object.
else


Loading…
Cancel
Save