diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index c9294653..6ab916b3 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -20,6 +20,8 @@ namespace Tensorflow { public partial class tensorflow { + public IInitializer constant_initializer(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) + => new Constant(value, dtype: dtype, verify_shape: verify_shape); public IInitializer zeros_initializer => new Zeros(); public IInitializer ones_initializer => new Ones(); public IInitializer glorot_uniform_initializer => new GlorotUniform(); diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs new file mode 100644 index 00000000..708d9db6 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -0,0 +1,55 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Operations.Initializers +{ + public class Constant : IInitializer + { + TF_DataType dtype; + T value; + bool _verify_shape; + + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) + { + this.value = value; + this.dtype = dtype; + _verify_shape = verify_shape; + } + + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + { + if (dtype == TF_DataType.DtInvalid) + dtype = this.dtype; + + if (!verify_shape.HasValue) + verify_shape = _verify_shape; + + return constant_op._constant_impl(value, dtype, shape, + name: "Const", + verify_shape: verify_shape.Value, + allow_broadcast: false); + } + + public object get_config() + { + return new + { + value, + dtype = dtype.name() + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 4ca6c140..0ac0865f 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -18,7 +18,7 @@ namespace Tensorflow { public interface IInitializer { - Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); + Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null); object get_config(); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 6fb4feb6..83e5b57d 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { if (dtype == TF_DataType.DtInvalid) dtype = this.dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index f553d45b..a3e2063f 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -38,7 +38,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { if (dtype == TF_DataType.DtInvalid) dtype = this.dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index 98595edc..59333c84 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Operations.Initializers } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { return random_ops.random_uniform(shape, minval: minval, diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 0611c0e9..7d635f0c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype) + public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) { return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index e88033b5..e2b2a0d6 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -45,7 +45,7 @@ namespace Tensorflow.Operations.Initializers _dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype) + public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) { var (fan_in, fan_out) = _compute_fans(shape); if (_mode == "fan_in") diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index b9d4f746..bea9cf71 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { if (dtype == TF_DataType.DtInvalid) dtype = this.dtype; diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 8c9e571e..8c2f543e 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -155,7 +155,7 @@ namespace Tensorflow public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static explicit operator int(TensorShape shape) => shape.size; - public static explicit operator TensorShape(int dim) => new TensorShape(dim); + public static implicit operator TensorShape(int dim) => new TensorShape(dim); public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);