Browse Source

Fix namespace compile issue.

tags/v0.100.5-BERT-load
Haiping Chen 2 years ago
parent
commit
df913078b9
4 changed files with 46 additions and 4 deletions
  1. +8
    -0
      src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
  2. +36
    -0
      src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
  3. +1
    -2
      test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
  4. +1
    -2
      test/TensorFlowNET.Graph.UnitTest/SignalTest.cs

+ 8
- 0
src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs View File

@@ -10,5 +10,13 @@ public interface IOptimizer
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);

void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);

IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
}

+ 36
- 0
src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs View File

@@ -78,6 +78,42 @@ namespace Tensorflow.Keras.Optimizers
});
}

public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
=> apply_gradients(new[] { grads_and_vars },
name: name,
experimental_aggregate_gradients: experimental_aggregate_gradients);

/// <summary>
/// Apply gradients to variables.
/// </summary>
/// <param name="grads_and_vars"></param>
/// <param name="name"></param>
/// <param name="experimental_aggregate_gradients"></param>
public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
{
var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
tf_with(ops.name_scope(_name), delegate
{
ops.init_scope();
_create_all_weights(var_list);
if (grads_and_vars == null || grads_and_vars.Count() == 0)
return control_flow_ops.no_op();

var apply_state = _prepare(var_list);
// if(experimental_aggregate_gradients)
{
// var reduced_grads = _aggregate_gradients(grads_and_vars);
_distributed_apply(grads_and_vars.Select(x => (x.Item1, (IVariableV1)x.Item2)), name, apply_state);
}

return null;
});
}

void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
_resource_apply_dense(var, grad, apply_state);


+ 1
- 2
test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs View File

@@ -5,8 +5,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;
using TensorFlowNET.Keras.UnitTest;
using Tensorflow.Keras.UnitTest;

namespace TensorFlowNET.UnitTest.Basics
{


+ 1
- 2
test/TensorFlowNET.Graph.UnitTest/SignalTest.cs View File

@@ -5,8 +5,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;
using TensorFlowNET.Keras.UnitTest;
using Tensorflow.Keras.UnitTest;

namespace TensorFlowNET.UnitTest.Basics
{


Loading…
Cancel
Save