Browse Source

gradients_impl. not finish yet.

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
c6f9ec6179
7 changed files with 109 additions and 5 deletions
  1. +15
    -0
      src/TensorFlowNET.Core/Gradients/AggregationMethod.cs
  2. +32
    -0
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +10
    -4
      src/TensorFlowNET.Core/Train/Optimizer.cs
  5. +12
    -0
      src/TensorFlowNET.Core/Train/_OptimizableVariable.cs
  6. +34
    -0
      src/TensorFlowNET.Core/Train/optimizer.py.cs
  7. +5
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs

+ 15
- 0
src/TensorFlowNET.Core/Gradients/AggregationMethod.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class AggregationMethod
{
public static int ADD_N = 0;
public static int DEFAULT = ADD_N;
// The following are experimental and may not be supported in future releases.
public static int EXPERIMENTAL_TREE = 1;
public static int EXPERIMENTAL_ACCUMULATE_N = 2;
}
}

+ 32
- 0
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class gradients_impl
{
public static void gradients(object ys,
object xs,
List<Tensor> grad_ys = null,
string name = "gradients",
bool colocate_gradients_with_ops = false,
bool gate_gradients = false,
int? aggregation_method = null)
{
_GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients);
}

public static void _GradientsHelper(object ys,
object xs,
List<Tensor> grad_ys = null,
string name = "gradients",
bool colocate_gradients_with_ops = false,
bool gate_gradients = false,
Graph src_graph = null)
{
if (src_graph == null)
src_graph = ops.get_default_graph();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -222,7 +222,7 @@ namespace Tensorflow
}
}

return "";
return $"{name} {dtype} {rank} {string.Join(",", shape)}";
}

public void Dispose()


+ 10
- 4
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using distribute_lib = Tensorflow.Distribute;

@@ -48,8 +49,10 @@ namespace Tensorflow
/// <param name="gate_gradients"></param>
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
List<RefVariable> var_list = null,
int? aggregation_method = null,
GateGradientType gate_gradients = GateGradientType.GATE_OP,
bool colocate_gradients_with_ops = false)
bool colocate_gradients_with_ops = false,
List<Tensor> grad_loss = null)
{
int num_towers = 1;
if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
@@ -65,10 +68,13 @@ namespace Tensorflow
break;
}

foreach(var v in var_list)
{
var processors = var_list.Select(v => optimizer._get_processor(v));
var var_refs = processors.Select(x => x.target()).ToList();

}
gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,
gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
aggregation_method: aggregation_method,
colocate_gradients_with_ops: colocate_gradients_with_ops);

return null;
}


+ 12
- 0
src/TensorFlowNET.Core/Train/_OptimizableVariable.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface _OptimizableVariable
{
Tensor target();
void update_op(Graph g);
}
}

+ 34
- 0
src/TensorFlowNET.Core/Train/optimizer.py.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class optimizer
{
public static _OptimizableVariable _get_processor(RefVariable v)
{
return new _RefVariableProcessor(v);
}
}

public class _RefVariableProcessor : _OptimizableVariable
{
private RefVariable _v;

public _RefVariableProcessor(RefVariable v)
{
_v = v;
}

public Tensor target()
{
return _v._ref();
}

public void update_op(Graph g)
{
}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -78,6 +78,11 @@ namespace Tensorflow
ops.add_to_collections(collections, this);
}

public Tensor _ref()
{
return _variable;
}

public static implicit operator _VariableScopeStore(RefVariable variable)
{
return null;


Loading…
Cancel
Save