Browse Source

make `Activation` as the implementation of `IActivation` and use `IActivation` in `ActivationAdapter` to receive more generally `activation`.

pull/1085/head
lingbai-kong 2 years ago
parent
commit
6b473e4a40
2 changed files with 27 additions and 12 deletions
  1. +24
    -9
      src/TensorFlowNET.Core/Keras/Activations/Activations.cs
  2. +3
    -3
      src/TensorFlowNET.Keras/Activations.cs

+ 24
- 9
src/TensorFlowNET.Core/Keras/Activations/Activations.cs View File

@@ -3,11 +3,12 @@ using System;
using System.Reflection; using System.Reflection;
using System.Runtime.Versioning; using System.Runtime.Versioning;
using Tensorflow.Keras.Saving.Common; using Tensorflow.Keras.Saving.Common;
using Tensorflow.Operations.Activation;


namespace Tensorflow.Keras namespace Tensorflow.Keras
{ {
[JsonConverter(typeof(CustomizedActivationJsonConverter))] [JsonConverter(typeof(CustomizedActivationJsonConverter))]
public class Activation
public class Activation : IActivation
{ {
public string Name { get; set; } public string Name { get; set; }
/// <summary> /// <summary>
@@ -15,7 +16,21 @@ namespace Tensorflow.Keras
/// </summary> /// </summary>
public Func<Tensor, string, Tensor> ActivationFunction { get; set; } public Func<Tensor, string, Tensor> ActivationFunction { get; set; }


public Tensor Apply(Tensor input, string name = null) => ActivationFunction(input, name);
/// <summary>
/// The implementation function of `IActivation`
/// </summary>
/// <param name="x"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor Activate(Tensor x, string name = null) => ActivationFunction(x, name);

/// <summary>
/// The function for calling in LayersApi, an alias for `Activate`.
/// </summary>
/// <param name="input"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor Apply(Tensor input, string name = null) => Activate(input, name);


public static implicit operator Activation(Func<Tensor, string, Tensor> func) public static implicit operator Activation(Func<Tensor, string, Tensor> func)
{ {
@@ -28,23 +43,23 @@ namespace Tensorflow.Keras
} }


/// <summary> /// <summary>
/// The ActivationAdapter is used to store string, Activation, and Func for Laysers Api to accept different types of activation parameters.
/// The `ActivationAdapter` is used to store the string, the `IActivation` implementation class, and the `Func` for LayersApi to accept different types of activation parameters.
/// One of the properties must be specified while initializing. /// One of the properties must be specified while initializing.
/// </summary> /// </summary>
public class ActivationAdapter public class ActivationAdapter
{ {
/// <summary> /// <summary>
/// The name of activaiton function, such as `tanh`, `sigmoid`.
/// The name of the activaiton function, such as "tanh", "sigmoid".
/// </summary> /// </summary>
public string? Name { get; set; } = null; public string? Name { get; set; } = null;


/// <summary> /// <summary>
/// The available Activation instance of activaiton function, such as keras.activations.Tanh, keras.activations.Sigmoid.
/// The available `IActivation` implementation class of the activaiton function, such as the `Activation` instances (keras.activations.Tanh, keras.activations.Sigmoid) and other `IActivation` implementation class.
/// </summary> /// </summary>
public Activation? Activation { get; set; } = null;
public IActivation? Activation { get; set; } = null;


/// <summary> /// <summary>
/// The Func definition of activation function, which can be customized.
/// The `Func` definition of the activation function, which can be customized.
/// </summary> /// </summary>
public Func<Tensor, string, Tensor>? Func { get; set; } = null; public Func<Tensor, string, Tensor>? Func { get; set; } = null;


@@ -53,7 +68,7 @@ namespace Tensorflow.Keras
Name = name; Name = name;
} }


public ActivationAdapter(Activation activation)
public ActivationAdapter(IActivation activation)
{ {
Activation = activation; Activation = activation;
} }
@@ -83,7 +98,7 @@ namespace Tensorflow.Keras
public interface IActivationsApi public interface IActivationsApi
{ {
Activation GetActivationFromName(string name); Activation GetActivationFromName(string name);
Activation GetActivationFromAdapter(ActivationAdapter adapter); Activation GetActivationFromAdapter(ActivationAdapter adapter);


Activation Linear { get; } Activation Linear { get; }


+ 3
- 3
src/TensorFlowNET.Keras/Activations.cs View File

@@ -94,8 +94,8 @@ namespace Tensorflow.Keras
} }


/// <summary> /// <summary>
/// Convert ActivationAdapter to Activation.
/// If more than one properties of ActivationAdapter are specified, the order of priority is `Name`, `Activation`, `Func`
/// Convert `ActivationAdapter` to `Activation`.
/// If more than one properties of `ActivationAdapter` are specified, the order of priority is `Name`, `Activation`, `Func`
/// </summary> /// </summary>
/// <param name="adapter"></param> /// <param name="adapter"></param>
/// <returns></returns> /// <returns></returns>
@@ -112,7 +112,7 @@ namespace Tensorflow.Keras
} }
else if(adapter.Activation != null) else if(adapter.Activation != null)
{ {
return adapter.Activation;
return (Activation) adapter.Activation;
} }
else if(adapter.Func != null) else if(adapter.Func != null)
{ {


Loading…
Cancel
Save