@@ -41,6 +41,9 @@ namespace Tensorflow | |||
/// <summary>
/// Creates an <see cref="AdamOptimizer"/> from a constant learning rate.
/// </summary>
public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
    => new AdamOptimizer(learning_rate, name: name);

/// <summary>
/// Creates an <see cref="AdamOptimizer"/> whose hyper-parameter tensors are
/// materialized with the given <paramref name="dtype"/>.
/// </summary>
public Optimizer AdamOptimizer(float learning_rate, TF_DataType dtype, string name = "Adam")
    => new AdamOptimizer(learning_rate, name: name, dtype: dtype);

/// <summary>
/// Creates an <see cref="AdamOptimizer"/> from a learning-rate tensor
/// (e.g. a decayed or scheduled rate).
/// </summary>
public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam")
    => new AdamOptimizer(learning_rate, name: name);
@@ -35,6 +35,25 @@ | |||
DtFloatRef = 101, // DT_FLOAT_REF | |||
DtDoubleRef = 102, // DT_DOUBLE_REF | |||
DtInt32Ref = 103, // DT_INT32_REF | |||
DtInt64Ref = 109 // DT_INT64_REF | |||
DtUint8Ref = 104, | |||
DtInt16Ref = 105, | |||
DtInt8Ref = 106, | |||
DtStringRef = 107, | |||
DtComplex64Ref = 108, | |||
DtInt64Ref = 109, // DT_INT64_REF | |||
DtBoolRef = 110, | |||
DtQint8Ref = 111, | |||
DtQuint8Ref = 112, | |||
DtQint32Ref = 113, | |||
DtBfloat16Ref = 114, | |||
DtQint16Ref = 115, | |||
DtQuint16Ref = 116, | |||
DtUint16Ref = 117, | |||
DtComplex128Ref = 118, | |||
DtHalfRef = 119, | |||
DtResourceRef = 120, | |||
DtVariantRef = 121, | |||
DtUint32Ref = 122, | |||
DtUint64Ref = 123, | |||
} | |||
} |
@@ -46,7 +46,7 @@ namespace Tensorflow | |||
/// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns> | |||
public static Type as_numpy_dtype(this TF_DataType type) | |||
{ | |||
switch (type) | |||
switch (type.as_base_dtype()) | |||
{ | |||
case TF_DataType.TF_BOOL: | |||
return typeof(bool); | |||
@@ -182,14 +182,12 @@ namespace Tensorflow | |||
public static DataType as_datatype_enum(this TF_DataType type) | |||
{ | |||
return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; | |||
return (DataType)type; | |||
} | |||
public static TF_DataType as_base_dtype(this TF_DataType type) | |||
{ | |||
return (int)type > 100 ? | |||
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) : | |||
type; | |||
return (int)type > 100 ? (TF_DataType)((int)type - 100) : type; | |||
} | |||
public static int name(this TF_DataType type) | |||
@@ -213,21 +211,17 @@ namespace Tensorflow | |||
public static DataType as_base_dtype(this DataType type) | |||
{ | |||
return (int)type > 100 ? | |||
(DataType)Enum.Parse(typeof(DataType), ((int)type - 100).ToString()) : | |||
type; | |||
return (int)type > 100 ? (DataType)((int)type - 100) : type; | |||
} | |||
public static TF_DataType as_tf_dtype(this DataType type) | |||
{ | |||
return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; | |||
return (TF_DataType)type; | |||
} | |||
public static TF_DataType as_ref(this TF_DataType type) | |||
{ | |||
return (int)type < 100 ? | |||
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) : | |||
type; | |||
return (int)type < 100 ? (TF_DataType)((int)type + 100) : type; | |||
} | |||
public static long max(this TF_DataType type) | |||
@@ -32,21 +32,24 @@ namespace Tensorflow.Train | |||
float _beta2; | |||
float _epsilon; | |||
Tensor _beta1_t, _beta2_t, _epsilon_t; | |||
TF_DataType _dtype; | |||
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") | |||
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam") | |||
: base(learning_rate, use_locking, name) | |||
{ | |||
_beta1 = beta1; | |||
_beta2 = beta2; | |||
_epsilon = epsilon; | |||
_dtype = dtype; | |||
} | |||
public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") | |||
public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam") | |||
: base(learning_rate, use_locking, name) | |||
{ | |||
_beta1 = beta1; | |||
_beta2 = beta2; | |||
_epsilon = epsilon; | |||
_dtype = dtype; | |||
} | |||
public override Operation _apply_sparse(IndexedSlices grad, RefVariable var) | |||
@@ -154,10 +157,10 @@ namespace Tensorflow.Train | |||
var beta2 = _call_if_callable(_beta2); | |||
var epsilon = _call_if_callable(_epsilon); | |||
_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate"); | |||
_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1"); | |||
_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2"); | |||
_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon"); | |||
_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate", dtype: _dtype); | |||
_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1", dtype: _dtype); | |||
_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2", dtype: _dtype); | |||
_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon", dtype: _dtype); | |||
} | |||
} | |||
} |
@@ -253,6 +253,7 @@ namespace Tensorflow | |||
v = variable_scope.default_variable_creator( | |||
initial_value, | |||
name: name, | |||
dtype: colocate_with.dtype.as_base_dtype(), | |||
trainable: false, | |||
use_resource: resource_variable_ops.is_resource_variable( | |||
colocate_with)); | |||