Oceania2018 5 years ago
parent
commit
c7a59e01d0
5 changed files with 39 additions and 19 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +20
    -1
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  3. +6
    -12
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +9
    -6
      src/TensorFlowNET.Core/Training/AdamOptimizer.cs
  5. +1
    -0
      src/TensorFlowNET.Core/Training/Optimizer.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -41,6 +41,9 @@ namespace Tensorflow
/// <summary>
/// Creates an <see cref="AdamOptimizer"/> driven by a constant learning rate.
/// </summary>
/// <param name="learning_rate">Fixed step size used by the optimizer.</param>
/// <param name="name">Name scope for the optimizer's operations.</param>
/// <returns>A new Adam optimizer instance.</returns>
public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
{
    return new AdamOptimizer(learning_rate, name: name);
}

/// <summary>
/// Creates an <see cref="AdamOptimizer"/> with a constant learning rate and an
/// explicit dtype for the optimizer's internal hyper-parameter tensors.
/// </summary>
/// <param name="learning_rate">Fixed step size used by the optimizer.</param>
/// <param name="dtype">Dtype forwarded to the optimizer (used when hyper-parameters are converted to tensors).</param>
/// <param name="name">Name scope for the optimizer's operations.</param>
/// <returns>A new Adam optimizer instance.</returns>
public Optimizer AdamOptimizer(float learning_rate, TF_DataType dtype, string name = "Adam")
{
    return new AdamOptimizer(learning_rate, name: name, dtype: dtype);
}

/// <summary>
/// Creates an <see cref="AdamOptimizer"/> whose learning rate is supplied as a
/// tensor (e.g. a decayed or scheduled rate computed in the graph).
/// </summary>
/// <param name="learning_rate">Tensor-valued step size.</param>
/// <param name="name">Name scope for the optimizer's operations.</param>
/// <returns>A new Adam optimizer instance.</returns>
public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam")
{
    return new AdamOptimizer(learning_rate, name: name);
}



+ 20
- 1
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -35,6 +35,25 @@
// Reference ("_REF") variants: each value is the base dtype's value + 100.
DtFloatRef = 101,      // DT_FLOAT_REF
DtDoubleRef = 102,     // DT_DOUBLE_REF
DtInt32Ref = 103,      // DT_INT32_REF
DtUint8Ref = 104,      // DT_UINT8_REF
DtInt16Ref = 105,      // DT_INT16_REF
DtInt8Ref = 106,       // DT_INT8_REF
DtStringRef = 107,     // DT_STRING_REF
DtComplex64Ref = 108,  // DT_COMPLEX64_REF
DtInt64Ref = 109,      // DT_INT64_REF
DtBoolRef = 110,       // DT_BOOL_REF
DtQint8Ref = 111,      // DT_QINT8_REF
DtQuint8Ref = 112,     // DT_QUINT8_REF
DtQint32Ref = 113,     // DT_QINT32_REF
DtBfloat16Ref = 114,   // DT_BFLOAT16_REF
DtQint16Ref = 115,     // DT_QINT16_REF
DtQuint16Ref = 116,    // DT_QUINT16_REF
DtUint16Ref = 117,     // DT_UINT16_REF
DtComplex128Ref = 118, // DT_COMPLEX128_REF
DtHalfRef = 119,       // DT_HALF_REF
DtResourceRef = 120,   // DT_RESOURCE_REF
DtVariantRef = 121,    // DT_VARIANT_REF
DtUint32Ref = 122,     // DT_UINT32_REF
DtUint64Ref = 123,     // DT_UINT64_REF
}
}

+ 6
- 12
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow
/// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns>
public static Type as_numpy_dtype(this TF_DataType type)
{
switch (type)
switch (type.as_base_dtype())
{
case TF_DataType.TF_BOOL:
return typeof(bool);
@@ -182,14 +182,12 @@ namespace Tensorflow

/// <summary>
/// Converts a <see cref="TF_DataType"/> to its protobuf <see cref="DataType"/> counterpart.
/// </summary>
/// <param name="type">The dtype to convert.</param>
/// <returns>The <see cref="DataType"/> with the same numeric value.</returns>
public static DataType as_datatype_enum(this TF_DataType type)
{
    // The two enums share the same numeric values, so a direct cast suffices and
    // avoids the old Enum.TryParse string round-trip. NOTE(review): unlike the
    // TryParse version, an out-of-range value is no longer mapped to DtInvalid.
    return (DataType)type;
}

/// <summary>
/// Returns the base (non-reference) dtype of <paramref name="type"/>.
/// </summary>
/// <param name="type">The dtype, possibly a "_REF" variant.</param>
/// <returns>The base dtype; unchanged if <paramref name="type"/> is not a ref dtype.</returns>
public static TF_DataType as_base_dtype(this TF_DataType type)
{
    // Ref dtypes are encoded as base value + 100 (e.g. DtFloatRef = 101),
    // so subtracting 100 recovers the base dtype without any string parsing.
    return (int)type > 100 ? (TF_DataType)((int)type - 100) : type;
}

public static int name(this TF_DataType type)
@@ -213,21 +211,17 @@ namespace Tensorflow

/// <summary>
/// Returns the base (non-reference) dtype of <paramref name="type"/>.
/// </summary>
/// <param name="type">The protobuf dtype, possibly a "_REF" variant.</param>
/// <returns>The base dtype; unchanged if <paramref name="type"/> is not a ref dtype.</returns>
public static DataType as_base_dtype(this DataType type)
{
    // Ref dtypes are encoded as base value + 100, so arithmetic on the
    // underlying integer recovers the base dtype directly.
    return (int)type > 100 ? (DataType)((int)type - 100) : type;
}

/// <summary>
/// Converts a protobuf <see cref="DataType"/> to the runtime <see cref="TF_DataType"/>.
/// </summary>
/// <param name="type">The protobuf dtype to convert.</param>
/// <returns>The <see cref="TF_DataType"/> with the same numeric value.</returns>
public static TF_DataType as_tf_dtype(this DataType type)
{
    // The enums share numeric values; a direct cast replaces the old
    // Enum.TryParse string round-trip. NOTE(review): out-of-range values are
    // no longer coerced to DtInvalid.
    return (TF_DataType)type;
}

/// <summary>
/// Returns the reference ("_REF") variant of <paramref name="type"/>.
/// </summary>
/// <param name="type">The dtype to convert.</param>
/// <returns>The ref dtype; unchanged if <paramref name="type"/> is already a ref dtype.</returns>
public static TF_DataType as_ref(this TF_DataType type)
{
    // Ref dtypes are encoded as base value + 100; values already >= 100 are
    // assumed to be ref dtypes and pass through unchanged.
    return (int)type < 100 ? (TF_DataType)((int)type + 100) : type;
}

public static long max(this TF_DataType type)


+ 9
- 6
src/TensorFlowNET.Core/Training/AdamOptimizer.cs View File

@@ -32,21 +32,24 @@ namespace Tensorflow.Train
float _beta2;
float _epsilon;
Tensor _beta1_t, _beta2_t, _epsilon_t;
TF_DataType _dtype;

/// <summary>
/// Constructs an Adam optimizer with a constant learning rate.
/// </summary>
/// <param name="learning_rate">Fixed step size.</param>
/// <param name="beta1">Decay rate for the first-moment estimates.</param>
/// <param name="beta2">Decay rate for the second-moment estimates.</param>
/// <param name="epsilon">Small constant for numerical stability.</param>
/// <param name="use_locking">Whether updates use locking; forwarded to the base optimizer.</param>
/// <param name="dtype">Dtype used when the scalar hyper-parameters are converted to tensors.</param>
/// <param name="name">Name scope for the optimizer's operations.</param>
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam")
    : base(learning_rate, use_locking, name)
{
    _beta1 = beta1;
    _beta2 = beta2;
    _epsilon = epsilon;
    _dtype = dtype;
}

/// <summary>
/// Constructs an Adam optimizer whose learning rate is supplied as a tensor
/// (e.g. a decayed or scheduled rate computed in the graph).
/// </summary>
/// <param name="learning_rate">Tensor-valued step size.</param>
/// <param name="beta1">Decay rate for the first-moment estimates.</param>
/// <param name="beta2">Decay rate for the second-moment estimates.</param>
/// <param name="epsilon">Small constant for numerical stability.</param>
/// <param name="use_locking">Whether updates use locking; forwarded to the base optimizer.</param>
/// <param name="dtype">Dtype used when the scalar hyper-parameters are converted to tensors.</param>
/// <param name="name">Name scope for the optimizer's operations.</param>
public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "Adam")
    : base(learning_rate, use_locking, name)
{
    _beta1 = beta1;
    _beta2 = beta2;
    _epsilon = epsilon;
    _dtype = dtype;
}

public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
@@ -154,10 +157,10 @@ namespace Tensorflow.Train
var beta2 = _call_if_callable(_beta2);
var epsilon = _call_if_callable(_epsilon);

_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate");
_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1");
_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2");
_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon");
_lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate", dtype: _dtype);
_beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1", dtype: _dtype);
_beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2", dtype: _dtype);
_epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon", dtype: _dtype);
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Training/Optimizer.cs View File

@@ -253,6 +253,7 @@ namespace Tensorflow
v = variable_scope.default_variable_creator(
initial_value,
name: name,
dtype: colocate_with.dtype.as_base_dtype(),
trainable: false,
use_resource: resource_variable_ops.is_resource_variable(
colocate_with));


Loading…
Cancel
Save