Browse Source

fixed #632

tags/v0.30
Oceania2018 4 years ago
parent
commit
ccc52dd98d
3 changed files with 20 additions and 3 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs
  2. +4
    -1
      src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
  3. +15
    -0
      src/TensorFlowNET.Core/Variables/state_ops.cs

+ 1
- 2
src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Engine
var data = iterator.next();
var outputs = train_step(data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
throw new NotImplementedException("");
return null;
}

/// <summary>
@@ -38,7 +38,6 @@ namespace Tensorflow.Keras.Engine
// The _minimize call does a few extra steps unnecessary in most cases,
// such as loss scaling and gradient clipping.
_minimize(tape, optimizer, loss, trainable_variables);

compiled_metrics.update_state(y, y_pred);
return new[] { ("loss", loss) };
}


+ 4
- 1
src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers
return control_flow_ops.no_op();

var apply_state = _prepare(var_list);
if(experimental_aggregate_gradients)
// if(experimental_aggregate_gradients)
{
// var reduced_grads = _aggregate_gradients(grads_and_vars);
_distributed_apply(grads_and_vars, name, apply_state);
@@ -84,6 +84,9 @@ namespace Tensorflow.Keras.Optimizers
void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
_resource_apply_dense(var, grad, apply_state);
// if var.constraint is not None:
// with ops.control_dependencies([update_op]):
// return var.assign(var.constraint(var))
}

protected virtual Operation _resource_apply_dense(IVariableV1 var,


+ 15
- 0
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -67,6 +67,21 @@ namespace Tensorflow
name: name);
}

public static Tensor assign(IVariableV1 @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
if (@ref.dtype.is_ref_dtype())
return gen_state_ops.assign(@ref,
value,
validate_shape: validate_shape,
use_locking: use_locking,
name: name);
else
return @ref.assign(value, name: name);
}

public static Tensor assign_sub(IVariableV1 @ref,
Tensor value,
bool use_locking = false,


Loading…
Cancel
Save