diff --git a/src/TensorFlowNET.Keras/Engine/BasePreprocessingLayer.cs b/src/TensorFlowNET.Keras/Engine/BasePreprocessingLayer.cs index 6beda9fe..61c57d39 100644 --- a/src/TensorFlowNET.Keras/Engine/BasePreprocessingLayer.cs +++ b/src/TensorFlowNET.Keras/Engine/BasePreprocessingLayer.cs @@ -1,10 +1,58 @@ -using System; +using Keras.Layers; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Data; +using Tensorflow.Keras.Initializers; namespace Tensorflow.Keras.Engine { - class BasePreprocessingLayer + public abstract class PreprocessingLayer : Layer { + public abstract void adapt(Data.DatasetV1 data, bool reset_state = true); + } + + public abstract class Combiner + { + public abstract dynamic compute(NDArray[] batch_values, dynamic accumulator = null); + + public abstract dynamic merge(dynamic[] accumulators); + + public abstract NDArray[] extract(dynamic accumulator); + + public abstract dynamic restore(Tensor output); + + public abstract string serialize(dynamic accumulator); + + public abstract dynamic deserialize(string encoded_accumulator); + + public override string ToString() + { + throw new NotImplementedException(); + } + } + + public class CombinerPreprocessingLayer : PreprocessingLayer + { + public CombinerPreprocessingLayer(Combiner combiner) + { + throw new NotImplementedException(); + } + + private void _add_state_variable(string name, TensorShape shape, string dtype, Initializer initializer= null, string partitioner= null, bool? use_resource= null) => throw new NotImplementedException(); + + private Dictionary _restore_updates() => throw new NotImplementedException(); + + private bool _dataset_is_infinite(DatasetV1 dataset) => throw new NotImplementedException(); + + private dynamic _get_dataset_iterator(DatasetV1 dataset) => throw new NotImplementedException(); + + private void _set_state_variables(Dictionary updates) => throw new NotImplementedException(); + + public override void adapt(DatasetV1 data, bool reset_state = true) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/CallContext.cs b/src/TensorFlowNET.Keras/Engine/CallContext.cs index 0b3c5ff2..5547c6de 100644 --- a/src/TensorFlowNET.Keras/Engine/CallContext.cs +++ b/src/TensorFlowNET.Keras/Engine/CallContext.cs @@ -1,10 +1,45 @@ -using System; +using Keras.Layers; +using System; using System.Collections.Generic; +using System.Reflection; using System.Text; namespace Tensorflow.Keras.Engine { public class CallContext { + public bool in_keras_graph + { + get + { + throw new NotImplementedException(); + } + } + public CallContext() + { + + } + + public void enter(Layer layer, Tensor[] inputs, Graph build_graph, bool training, Saving saving = null) => throw new NotImplementedException(); + + public bool training_arg_passed_to_call(string[] argspec, Dictionary args, Dictionary kwargs) => throw new NotImplementedException(); + + public dynamic autocast_context_manager(string dtype) => throw new NotImplementedException(); + + public bool is_subclassed(Layer layer) => throw new NotImplementedException(); + + public bool from_saved_model(Layer layer) => throw new NotImplementedException(); + + public bool check_graph_consistency(Tensor tensor = null, string method = "add_loss", bool force_raise = false) => throw new NotImplementedException(); + + public dynamic mark_as_return(Tensor[] outputs, dynamic acd) => throw new NotImplementedException(); + + public MethodInfo Default(MemberInfo method) => throw new NotImplementedException(); + + public void enable_v2_dtype_behavior() => throw new NotImplementedException(); + + public void disable_v2_dtype_behavior() => throw new NotImplementedException(); + + public void v2_dtype_behavior_enabled() => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Keras/Engine/Saving.cs b/src/TensorFlowNET.Keras/Engine/Saving.cs index 43ba2cf6..8ba804c3 100644 --- a/src/TensorFlowNET.Keras/Engine/Saving.cs +++ b/src/TensorFlowNET.Keras/Engine/Saving.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Keras.Engine { - class Saving + public class Saving { } } diff --git a/src/TensorFlowNET.Keras/Engine/TrackableWeightHandler.cs b/src/TensorFlowNET.Keras/Engine/TrackableWeightHandler.cs new file mode 100644 index 00000000..c6305809 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/TrackableWeightHandler.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public class TrackableWeightHandler + { + public int num_tensors + { + get + { + throw new NotImplementedException(); + } + } + + public TrackableWeightHandler(bool trackable) + { + throw new NotImplementedException(); + } + + public void set_weights(Tensor[] weights) => throw new NotImplementedException(); + + public void _set_weights_v1(Tensor[] weights) => throw new NotImplementedException(); + } +}