Browse Source

compute_gradients

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
4c2f76e2aa
7 changed files with 130 additions and 1 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/Train/Distribute.cs
  2. +13
    -0
      src/TensorFlowNET.Core/Train/GateGradientType.cs
  3. +16
    -0
      src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
  4. +55
    -0
      src/TensorFlowNET.Core/Train/Optimizer.cs
  5. +14
    -0
      src/TensorFlowNET.Core/Train/VariableAggregationType.cs
  6. +17
    -0
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  7. +1
    -1
      test/TensorFlowNET.Examples/LinearRegression.cs

+ 14
- 0
src/TensorFlowNET.Core/Train/Distribute.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Placeholder for TensorFlow's distribution-strategy helpers.
/// </summary>
public static class Distribute
{
    /// <summary>
    /// Return the reduction applied to per-tower losses.
    /// Currently hard-wired to <see cref="VariableAggregationType.MEAN"/>.
    /// </summary>
    public static VariableAggregationType get_loss_reduction()
        => VariableAggregationType.MEAN;
}
}

+ 13
- 0
src/TensorFlowNET.Core/Train/GateGradientType.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Controls how much gating (serialization) is applied while computing gradients,
/// mirroring TensorFlow's GATE_NONE / GATE_OP / GATE_GRAPH options.
/// </summary>
public enum GateGradientType
{
// No gating: maximum parallelism, but gradients may be applied while others are still computing.
GATE_NONE = 0,
// Gate per op: all gradients for an op are computed before any are used.
GATE_OP = 1,
// Gate the whole graph: all gradients for all variables are computed before any are used.
GATE_GRAPH = 2
}
}

+ 16
- 0
src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Optimizer that implements the plain gradient descent algorithm.
/// </summary>
public class GradientDescentOptimizer : Optimizer
{
    /// <summary>
    /// Construct a new gradient descent optimizer.
    /// </summary>
    /// <param name="learning_rate">Step size used when applying gradients.</param>
    /// <param name="use_locking">If true, use locks for the update operations.</param>
    /// <param name="name">Name prefix for the operations created; defaults to "GradientDescent".</param>
    public GradientDescentOptimizer(double learning_rate, bool use_locking = false, string name = "GradientDescent")
        : base(learning_rate, use_locking, name)
    {
        LearningRate = learning_rate;
        // The original also assigned LearningRateTensor = null, which is a no-op:
        // reference-type auto-properties already default to null. Removed.
    }
}
}

+ 55
- 0
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Text;
using distribute_lib = Tensorflow.Distribute;

namespace Tensorflow
{
/// <summary>
/// Base class for optimizers.
/// This class defines the API to add Ops to train a model. You never use this
/// class directly, but instead instantiate one of its subclasses such as
/// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
/// </summary>
/// <summary>
/// Base class for optimizers.
/// This class defines the API to add Ops to train a model. You never use this
/// class directly, but instead instantiate one of its subclasses such as
/// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
/// </summary>
public abstract class Optimizer
{
    public string Name { get; set; }
    public double LearningRate { get; set; }
    public Tensor LearningRateTensor { get; set; }
    // Whether locks are used when applying gradient updates (not yet wired into any op).
    public bool UseLocking { get; set; }

    public Optimizer(double learning_rate, bool use_locking, string name = "")
    {
        // Bug fix: a missing name is invalid input, so throw ArgumentException —
        // the original threw NotImplementedException, which signals an
        // unimplemented member, not a bad argument.
        if (string.IsNullOrEmpty(name))
            throw new ArgumentException("Must specify the optimizer name", nameof(name));

        Name = name;
        // Bug fix: the original constructor silently discarded both parameters.
        LearningRate = learning_rate;
        UseLocking = use_locking;
    }

    /// <summary>
    /// Add operations to minimize `loss` by updating `var_list`.
    /// </summary>
    /// <param name="loss">Loss tensor to minimize.</param>
    /// <param name="gate_gradients">How much gating is applied while computing gradients.</param>
    /// <returns>This optimizer, allowing fluent chaining.</returns>
    public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP)
    {
        compute_gradients(loss, gate_gradients);
        return this;
    }

    /// <summary>
    /// Compute gradients of `loss` for the variables in `var_list`.
    /// NOTE(review): still a stub — it always returns null and the MEAN-reduction
    /// branch is empty. Callers must tolerate a null result until implemented.
    /// </summary>
    /// <param name="loss">Loss tensor to differentiate.</param>
    /// <param name="gate_gradients">How much gating is applied while computing gradients.</param>
    /// <returns>List of (gradient, variable) pairs; currently always null.</returns>
    public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP)
    {
        int num_towers = 1;
        if (distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
        {
            // TODO: scale gradients by 1 / num_towers when loss reduction is MEAN.
        }

        return null;
    }
}
}

+ 14
- 0
src/TensorFlowNET.Core/Train/VariableAggregationType.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// How a distributed variable's values are aggregated across towers/replicas,
/// mirroring TensorFlow's VariableAggregation options.
/// </summary>
public enum VariableAggregationType
{
// No aggregation is performed.
NONE = 0,
// Sum the values across towers.
SUM = 1,
// Average the values across towers.
MEAN = 2,
// Take the value from the first tower only.
ONLY_FIRST_TOWER = 3
}
}

+ 17
- 0
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Entry points mirroring TensorFlow's `tf.train` API surface.
/// </summary>
public static partial class tf
{
    public static class train
    {
        /// <summary>
        /// Create a gradient descent optimizer with the given learning rate.
        /// </summary>
        /// <param name="learning_rate">Step size used when applying gradients.</param>
        /// <returns>A new <see cref="Tensorflow.GradientDescentOptimizer"/> as its base type.</returns>
        public static Optimizer GradientDescentOptimizer(double learning_rate)
            => new GradientDescentOptimizer(learning_rate);
    }
}
}

+ 1
- 1
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -47,7 +47,7 @@ namespace TensorFlowNET.Examples

// Gradient descent
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
// var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
}
}
}

Loading…
Cancel
Save