Browse Source

fixed #632

tags/v0.30
Oceania2018 4 years ago
parent
commit
ccc52dd98d
3 changed files with 20 additions and 3 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs
  2. +4
    -1
      src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
  3. +15
    -0
      src/TensorFlowNET.Core/Variables/state_ops.cs

+ 1
- 2
src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Engine
  var data = iterator.next();
  var outputs = train_step(data[0], data[1]);
  tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
- throw new NotImplementedException("");
+ return null;
  }


/// <summary> /// <summary>
@@ -38,7 +38,6 @@ namespace Tensorflow.Keras.Engine
  // The _minimize call does a few extra steps unnecessary in most cases,
  // such as loss scaling and gradient clipping.
  _minimize(tape, optimizer, loss, trainable_variables);
-
  compiled_metrics.update_state(y, y_pred);
  return new[] { ("loss", loss) };
  }


+ 4
- 1
src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers
  return control_flow_ops.no_op();

  var apply_state = _prepare(var_list);
- if(experimental_aggregate_gradients)
+ // if(experimental_aggregate_gradients)
  {
      // var reduced_grads = _aggregate_gradients(grads_and_vars);
      _distributed_apply(grads_and_vars, name, apply_state);
@@ -84,6 +84,9 @@ namespace Tensorflow.Keras.Optimizers
  void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
  {
      _resource_apply_dense(var, grad, apply_state);
+     // if var.constraint is not None:
+     //     with ops.control_dependencies([update_op]):
+     //         return var.assign(var.constraint(var))
  }


  protected virtual Operation _resource_apply_dense(IVariableV1 var,


+ 15
- 0
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -67,6 +67,21 @@ namespace Tensorflow
      name: name);
  }


// Assigns `value` to the variable `@ref` and returns the tensor produced
// by the assignment.
// NOTE(review): presumably mirrors Python TF's `state_ops.assign`, which
// dispatches on ref-dtype vs. resource variables — confirm against upstream.
public static Tensor assign(IVariableV1 @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
{
// Legacy ref-dtype variables go through the graph-mode assign op; this is
// the only path where validate_shape/use_locking are forwarded.
if (@ref.dtype.is_ref_dtype())
return gen_state_ops.assign(@ref,
value,
validate_shape: validate_shape,
use_locking: use_locking,
name: name);
else
// Resource variables delegate to the variable's own assign method;
// validate_shape and use_locking are NOT forwarded on this path.
return @ref.assign(value, name: name);
}

  public static Tensor assign_sub(IVariableV1 @ref,
      Tensor value,
      bool use_locking = false,


Loading…
Cancel
Save