the bool to tensor has a bug, if in init the training is False, the program not start.tags/v0.110.4-Transformer-Model
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras | |||||
List<ILayer> Layers { get; } | List<ILayer> Layers { get; } | ||||
List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); | |||||
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); | |||||
List<IVariableV1> TrainableVariables { get; } | List<IVariableV1> TrainableVariables { get; } | ||||
List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
@@ -145,7 +145,7 @@ namespace Tensorflow | |||||
throw new NotImplementedException("_zero_state_tensors"); | throw new NotImplementedException("_zero_state_tensors"); | ||||
} | } | ||||
public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) | |||||
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="state"></param> | /// <param name="state"></param> | ||||
/// <param name="training"></param> | /// <param name="training"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) | |||||
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null) | |||||
{ | { | ||||
if (callContext.Value == null) | if (callContext.Value == null) | ||||
callContext.Value = new CallContext(); | callContext.Value = new CallContext(); | ||||
@@ -142,6 +142,7 @@ namespace Tensorflow.Keras.Engine | |||||
int verbose = 1, | int verbose = 1, | ||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
IDatasetV2 validation_data = null, | IDatasetV2 validation_data = null, | ||||
int validation_step = 10, // 间隔多少次会进行一次验证 | |||||
bool shuffle = true, | bool shuffle = true, | ||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
@@ -164,11 +165,11 @@ namespace Tensorflow.Keras.Engine | |||||
}); | }); | ||||
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | |||||
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, | |||||
train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
} | } | ||||
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||||
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, | |||||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | ||||
{ | { | ||||
stop_training = false; | stop_training = false; | ||||
@@ -207,6 +208,9 @@ namespace Tensorflow.Keras.Engine | |||||
if (validation_data != null) | if (validation_data != null) | ||||
{ | { | ||||
if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) | |||||
continue; | |||||
var val_logs = evaluate(validation_data); | var val_logs = evaluate(validation_data); | ||||
foreach(var log in val_logs) | foreach(var log in val_logs) | ||||
{ | { | ||||
@@ -393,7 +393,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
} | } | ||||
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null) | |||||
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null) | |||||
{ | { | ||||
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | ||||
if (optional_args is not null && rnn_optional_args is null) | if (optional_args is not null && rnn_optional_args is null) | ||||