@@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Engine | |||||
var data = iterator.next(); | var data = iterator.next(); | ||||
var outputs = train_step(data[0], data[1]); | var outputs = train_step(data[0], data[1]); | ||||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | ||||
throw new NotImplementedException(""); | |||||
return null; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -38,7 +38,6 @@ namespace Tensorflow.Keras.Engine | |||||
// The _minimize call does a few extra steps unnecessary in most cases, | // The _minimize call does a few extra steps unnecessary in most cases, | ||||
// such as loss scaling and gradient clipping. | // such as loss scaling and gradient clipping. | ||||
_minimize(tape, optimizer, loss, trainable_variables); | _minimize(tape, optimizer, loss, trainable_variables); | ||||
compiled_metrics.update_state(y, y_pred); | compiled_metrics.update_state(y, y_pred); | ||||
return new[] { ("loss", loss) }; | return new[] { ("loss", loss) }; | ||||
} | } | ||||
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers | |||||
return control_flow_ops.no_op(); | return control_flow_ops.no_op(); | ||||
var apply_state = _prepare(var_list); | var apply_state = _prepare(var_list); | ||||
if(experimental_aggregate_gradients) | |||||
// if(experimental_aggregate_gradients) | |||||
{ | { | ||||
// var reduced_grads = _aggregate_gradients(grads_and_vars); | // var reduced_grads = _aggregate_gradients(grads_and_vars); | ||||
_distributed_apply(grads_and_vars, name, apply_state); | _distributed_apply(grads_and_vars, name, apply_state); | ||||
@@ -84,6 +84,9 @@ namespace Tensorflow.Keras.Optimizers | |||||
void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | ||||
{ | { | ||||
_resource_apply_dense(var, grad, apply_state); | _resource_apply_dense(var, grad, apply_state); | ||||
// if var.constraint is not None: | |||||
// with ops.control_dependencies([update_op]): | |||||
// return var.assign(var.constraint(var)) | |||||
} | } | ||||
protected virtual Operation _resource_apply_dense(IVariableV1 var, | protected virtual Operation _resource_apply_dense(IVariableV1 var, | ||||
@@ -67,6 +67,21 @@ namespace Tensorflow | |||||
name: name); | name: name); | ||||
} | } | ||||
public static Tensor assign(IVariableV1 @ref, object value, | |||||
bool validate_shape = true, | |||||
bool use_locking = true, | |||||
string name = null) | |||||
{ | |||||
if (@ref.dtype.is_ref_dtype()) | |||||
return gen_state_ops.assign(@ref, | |||||
value, | |||||
validate_shape: validate_shape, | |||||
use_locking: use_locking, | |||||
name: name); | |||||
else | |||||
return @ref.assign(value, name: name); | |||||
} | |||||
public static Tensor assign_sub(IVariableV1 @ref, | public static Tensor assign_sub(IVariableV1 @ref, | ||||
Tensor value, | Tensor value, | ||||
bool use_locking = false, | bool use_locking = false, | ||||