Browse Source

Initializer that generates tensors with constant values with tf.constant_initializer().

tags/v0.12
Oceania2018 6 years ago
parent
commit
6badc99b88
10 changed files with 65 additions and 8 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/APIs/tf.init.cs
  2. +55
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.cs

+ 2
- 0
src/TensorFlowNET.Core/APIs/tf.init.cs View File

@@ -20,6 +20,8 @@ namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
{ {
public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false)
=> new Constant<T>(value, dtype: dtype, verify_shape: verify_shape);
public IInitializer zeros_initializer => new Zeros(); public IInitializer zeros_initializer => new Zeros();
public IInitializer ones_initializer => new Ones(); public IInitializer ones_initializer => new Ones();
public IInitializer glorot_uniform_initializer => new GlorotUniform(); public IInitializer glorot_uniform_initializer => new GlorotUniform();


+ 55
- 0
src/TensorFlowNET.Core/Operations/Initializers/Constant.cs View File

@@ -0,0 +1,55 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

namespace Tensorflow.Operations.Initializers
{
public class Constant<T> : IInitializer
{
TF_DataType dtype;
T value;
bool _verify_shape;

public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false)
{
this.value = value;
this.dtype = dtype;
_verify_shape = verify_shape;
}

public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
{
if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype;

if (!verify_shape.HasValue)
verify_shape = _verify_shape;

return constant_op._constant_impl(value, dtype, shape,
name: "Const",
verify_shape: verify_shape.Value,
allow_broadcast: false);
}

public object get_config()
{
return new
{
value,
dtype = dtype.name()
};
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
{ {
public interface IInitializer public interface IInitializer
{ {
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid);
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null);
object get_config(); object get_config();
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/Ones.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype; this.dtype = dtype;
} }


public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
{ {
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype; dtype = this.dtype;


+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype; this.dtype = dtype;
} }


public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
{ {
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype; dtype = this.dtype;


+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Operations.Initializers


} }


public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
{ {
return random_ops.random_uniform(shape, return random_ops.random_uniform(shape,
minval: minval, minval: minval,


+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype; this.dtype = dtype;
} }


public Tensor call(TensorShape shape, TF_DataType dtype)
public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
{ {
return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed);
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -45,7 +45,7 @@ namespace Tensorflow.Operations.Initializers
_dtype = dtype; _dtype = dtype;
} }


public Tensor call(TensorShape shape, TF_DataType dtype)
public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
{ {
var (fan_in, fan_out) = _compute_fans(shape); var (fan_in, fan_out) = _compute_fans(shape);
if (_mode == "fan_in") if (_mode == "fan_in")


+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype; this.dtype = dtype;
} }


public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
{ {
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype; dtype = this.dtype;


+ 1
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -155,7 +155,7 @@ namespace Tensorflow
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);


public static explicit operator int(TensorShape shape) => shape.size; public static explicit operator int(TensorShape shape) => shape.size;
public static explicit operator TensorShape(int dim) => new TensorShape(dim);
public static implicit operator TensorShape(int dim) => new TensorShape(dim);


public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);


Loading…
Cancel
Save