@@ -20,6 +20,8 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
{ | { | ||||
public IInitializer constant_initializer<T>(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | |||||
=> new Constant<T>(value, dtype: dtype, verify_shape: verify_shape); | |||||
public IInitializer zeros_initializer => new Zeros(); | public IInitializer zeros_initializer => new Zeros(); | ||||
public IInitializer ones_initializer => new Ones(); | public IInitializer ones_initializer => new Ones(); | ||||
public IInitializer glorot_uniform_initializer => new GlorotUniform(); | public IInitializer glorot_uniform_initializer => new GlorotUniform(); | ||||
@@ -0,0 +1,55 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
namespace Tensorflow.Operations.Initializers | |||||
{ | |||||
public class Constant<T> : IInitializer | |||||
{ | |||||
TF_DataType dtype; | |||||
T value; | |||||
bool _verify_shape; | |||||
public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | |||||
{ | |||||
this.value = value; | |||||
this.dtype = dtype; | |||||
_verify_shape = verify_shape; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
{ | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = this.dtype; | |||||
if (!verify_shape.HasValue) | |||||
verify_shape = _verify_shape; | |||||
return constant_op._constant_impl(value, dtype, shape, | |||||
name: "Const", | |||||
verify_shape: verify_shape.Value, | |||||
allow_broadcast: false); | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new | |||||
{ | |||||
value, | |||||
dtype = dtype.name() | |||||
}; | |||||
} | |||||
} | |||||
} |
@@ -18,7 +18,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public interface IInitializer | public interface IInitializer | ||||
{ | { | ||||
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); | |||||
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null); | |||||
object get_config(); | object get_config(); | ||||
} | } | ||||
} | } |
@@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers | |||||
this.dtype = dtype; | this.dtype = dtype; | ||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
{ | { | ||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = this.dtype; | dtype = this.dtype; | ||||
@@ -38,7 +38,7 @@ namespace Tensorflow.Operations.Initializers | |||||
this.dtype = dtype; | this.dtype = dtype; | ||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
{ | { | ||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = this.dtype; | dtype = this.dtype; | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow.Operations.Initializers | |||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
{ | { | ||||
return random_ops.random_uniform(shape, | return random_ops.random_uniform(shape, | ||||
minval: minval, | minval: minval, | ||||
@@ -34,7 +34,7 @@ namespace Tensorflow.Operations.Initializers | |||||
this.dtype = dtype; | this.dtype = dtype; | ||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | |||||
{ | { | ||||
return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); | return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); | ||||
} | } | ||||
@@ -45,7 +45,7 @@ namespace Tensorflow.Operations.Initializers | |||||
_dtype = dtype; | _dtype = dtype; | ||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | |||||
{ | { | ||||
var (fan_in, fan_out) = _compute_fans(shape); | var (fan_in, fan_out) = _compute_fans(shape); | ||||
if (_mode == "fan_in") | if (_mode == "fan_in") | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers | |||||
this.dtype = dtype; | this.dtype = dtype; | ||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) | |||||
{ | { | ||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = this.dtype; | dtype = this.dtype; | ||||
@@ -155,7 +155,7 @@ namespace Tensorflow | |||||
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | ||||
public static explicit operator int(TensorShape shape) => shape.size; | public static explicit operator int(TensorShape shape) => shape.size; | ||||
public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
public static implicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | ||||
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | ||||