|
|
@@ -1,6 +1,8 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow.Train |
|
|
|
{ |
|
|
@@ -11,6 +13,7 @@ namespace Tensorflow.Train |
|
|
|
bool _zero_debias; |
|
|
|
string _name; |
|
|
|
public string name => _name; |
|
|
|
List<VariableV1> _averages; |
|
|
|
|
|
|
|
public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, |
|
|
|
string name = "ExponentialMovingAverage") |
|
|
@@ -19,6 +22,7 @@ namespace Tensorflow.Train |
|
|
|
_num_updates = num_updates; |
|
|
|
_zero_debias = zero_debias; |
|
|
|
_name = name; |
|
|
|
_averages = new List<VariableV1>(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
@@ -26,11 +30,23 @@ namespace Tensorflow.Train |
|
|
|
/// </summary> |
|
|
|
/// <param name="var_list"></param> |
|
|
|
/// <returns></returns> |
|
|
|
public Operation apply(VariableV1[] var_list = null) |
|
|
|
public Operation apply(RefVariable[] var_list = null) |
|
|
|
{ |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
if (var_list == null) |
|
|
|
var_list = variables.trainable_variables() as RefVariable[]; |
|
|
|
|
|
|
|
foreach(var var in var_list) |
|
|
|
{ |
|
|
|
if (!_averages.Contains(var)) |
|
|
|
{ |
|
|
|
ops.init_scope(); |
|
|
|
var slot = new SlotCreator(); |
|
|
|
var.initialized_value(); |
|
|
|
// var avg = slot.create_zeros_slot |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
} |
|
|
|
} |