diff --git a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj index 6cc631f4..89884977 100644 --- a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj +++ b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.1 + net5.0 Tensorflow Tensorflow diff --git a/src/TensorFlowNET.Core/Keras/Engine/INode.cs b/src/TensorFlowNET.Core/Keras/Engine/INode.cs index 1305a772..dde0f8ea 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/INode.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/INode.cs @@ -1,4 +1,6 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { @@ -10,5 +12,7 @@ namespace Tensorflow.Keras.Engine List KerasInputs { get; set; } INode[] ParentNodes { get; } IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); + bool is_input { get; } + NodeConfig serialize(Func make_node_key, Dictionary node_conversion_map); } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index c096458f..d6bbf11a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; namespace Tensorflow.Keras @@ -14,5 +15,6 @@ namespace Tensorflow.Keras List trainable_variables { get; } TensorShape output_shape { get; } int count_params(); + LayerArgs get_config(); } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs new file mode 100644 index 00000000..950b3132 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Saving +{ + public class LayerConfig + { + public string Name { get; set; } + public string ClassName { get; set; } + public LayerArgs Config { get; set; } + public List InboundNodes { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs new file mode 100644 index 00000000..fa965aa4 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Saving +{ + public class ModelConfig + { + public string Name { get; set; } + public List Layers { get; set; } + public List InputLayers { get; set; } + public List OutputLayers { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs new file mode 100644 index 00000000..732d9d4d --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public class NodeConfig + { + public string Name { get; set; } + public int NodeIndex { get; set; } + public int TensorIndex { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs new file mode 100644 index 00000000..4437ba0a --- /dev/null +++ b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Train; + +namespace Tensorflow.ModelSaving +{ + public class ModelSaver + { + public void save(Trackable obj, string export_dir, SaveOptions options = null) + { + var saved_model = new SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + _build_meta_graph(obj, export_dir, options, meta_graph_def); + } + + void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options, + MetaGraphDef meta_graph_def = null) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs new file mode 100644 index 00000000..e25537d8 --- /dev/null +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.ModelSaving +{ + /// + /// Options for saving to SavedModel. + /// + public class SaveOptions + { + bool save_debug_info; + public SaveOptions(bool save_debug_info = false) + { + this.save_debug_info = save_debug_info; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 3c717825..b2e1566b 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using Tensorflow.Keras; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Operations; using Tensorflow.Util; @@ -132,5 +133,10 @@ namespace Tensorflow { throw new NotImplementedException(); } + + public LayerArgs get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Protobuf/Gen.bat b/src/TensorFlowNET.Core/Protobuf/Gen.bat index 745235af..c6256737 100644 --- a/src/TensorFlowNET.Core/Protobuf/Gen.bat +++ b/src/TensorFlowNET.Core/Protobuf/Gen.bat @@ -27,6 +27,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.pro protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto diff --git a/src/TensorFlowNET.Core/Protobuf/SavedModel.cs b/src/TensorFlowNET.Core/Protobuf/SavedModel.cs new file mode 100644 index 00000000..e7b9259a --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/SavedModel.cs @@ -0,0 +1,210 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/saved_model.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/saved_model.proto + public static partial class SavedModelReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/saved_model.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SavedModelReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cip0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfbW9kZWwucHJvdG8S", + "CnRlbnNvcmZsb3caKXRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi9tZXRhX2dy", + "YXBoLnByb3RvIl8KClNhdmVkTW9kZWwSIgoac2F2ZWRfbW9kZWxfc2NoZW1h", + "X3ZlcnNpb24YASABKAMSLQoLbWV0YV9ncmFwaHMYAiADKAsyGC50ZW5zb3Jm", + "bG93Lk1ldGFHcmFwaERlZkJ7ChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", + "EFNhdmVkTW9kZWxQcm90b3NQAVpIZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", + "bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2NvcmVfcHJvdG9zX2dvX3By", + "b3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.MetaGraphReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedModel), global::Tensorflow.SavedModel.Parser, new[]{ "SavedModelSchemaVersion", "MetaGraphs" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// SavedModel is the high level serialization format for TensorFlow Models. + /// See [todo: doc links, similar to session_bundle] for more information. + /// + public sealed partial class SavedModel : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedModel()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedModelReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedModel() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedModel(SavedModel other) : this() { + savedModelSchemaVersion_ = other.savedModelSchemaVersion_; + metaGraphs_ = other.metaGraphs_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedModel Clone() { + return new SavedModel(this); + } + + /// Field number for the "saved_model_schema_version" field. + public const int SavedModelSchemaVersionFieldNumber = 1; + private long savedModelSchemaVersion_; + /// + /// The schema version of the SavedModel instance. Used for versioning when + /// making future changes to the specification/implementation. Initial value + /// at release will be 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long SavedModelSchemaVersion { + get { return savedModelSchemaVersion_; } + set { + savedModelSchemaVersion_ = value; + } + } + + /// Field number for the "meta_graphs" field. + public const int MetaGraphsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_metaGraphs_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.MetaGraphDef.Parser); + private readonly pbc::RepeatedField metaGraphs_ = new pbc::RepeatedField(); + /// + /// One or more MetaGraphs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MetaGraphs { + get { return metaGraphs_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SavedModel); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SavedModel other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (SavedModelSchemaVersion != other.SavedModelSchemaVersion) return false; + if(!metaGraphs_.Equals(other.metaGraphs_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (SavedModelSchemaVersion != 0L) hash ^= SavedModelSchemaVersion.GetHashCode(); + hash ^= metaGraphs_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (SavedModelSchemaVersion != 0L) { + output.WriteRawTag(8); + output.WriteInt64(SavedModelSchemaVersion); + } + metaGraphs_.WriteTo(output, _repeated_metaGraphs_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (SavedModelSchemaVersion != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(SavedModelSchemaVersion); + } + size += metaGraphs_.CalculateSize(_repeated_metaGraphs_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SavedModel other) { + if (other == null) { + return; + } + if (other.SavedModelSchemaVersion != 0L) { + SavedModelSchemaVersion = other.SavedModelSchemaVersion; + } + metaGraphs_.Add(other.metaGraphs_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + SavedModelSchemaVersion = input.ReadInt64(); + break; + } + case 18: { + metaGraphs_.AddEntriesFrom(input, _repeated_metaGraphs_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs new file mode 100644 index 00000000..96ea11b6 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -0,0 +1,76 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Functional + { + public ModelConfig get_config() + { + return get_network_config(); + } + + /// + /// Builds the config, which consists of the node graph and serialized layers. + /// + ModelConfig get_network_config() + { + var config = new ModelConfig + { + Name = name + }; + + var node_conversion_map = new Dictionary(); + foreach (var layer in _layers) + { + var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + { + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key)) + { + node_conversion_map[node_key] = kept_nodes; + kept_nodes += 1; + } + } + } + + var layer_configs = new List(); + foreach (var layer in _layers) + { + var filtered_inbound_nodes = new List(); + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + { + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key) && !node.is_input) + { + var node_data = node.serialize(_make_node_key, node_conversion_map); + throw new NotImplementedException(""); + } + } + + var layer_config = generic_utils.serialize_keras_object(layer); + layer_config.Name = layer.Name; + layer_config.InboundNodes = filtered_inbound_nodes; + layer_configs.Add(layer_config); + } + config.Layers = layer_configs; + + return config; + } + + bool _should_skip_first_node(ILayer layer) + { + return layer is Functional && layer.Layers[0] is InputLayer; + } + + string _make_node_key(string layer_name, int node_index) + => $"{layer_name}_ib-{node_index}"; + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 68027b07..d55b6781 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine /// /// A `Functional` model is a `Model` defined as a directed graph of layers. /// - public class Functional : Model + public partial class Functional : Model { TensorShape _build_input_shape; bool _compute_output_and_mask_jointly; @@ -283,7 +283,7 @@ namespace Tensorflow.Keras.Engine // Propagate to all previous tensors connected to this node. nodes_in_progress.Add(node); - if (!node.IsInput) + if (!node.is_input) { foreach (var k_tensor in node.KerasInputs) { @@ -307,7 +307,7 @@ namespace Tensorflow.Keras.Engine Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) { - if (mask != null) + if (mask == null) { Tensor[] masks = new Tensor[inputs.Count()]; foreach (var (i, input_t) in enumerate(inputs)) @@ -330,7 +330,7 @@ namespace Tensorflow.Keras.Engine foreach (Node node in nodes) { // Input tensors already exist. - if (node.IsInput) + if (node.is_input) continue; var layer_inputs = node.MapArguments(tensor_dict); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 0c892535..93c9f91f 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using System.Threading; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; using Tensorflow.Train; using static Tensorflow.Binding; @@ -241,5 +242,8 @@ namespace Tensorflow.Keras.Engine return weights; } } + + public virtual LayerArgs get_config() + => throw new NotImplementedException(""); } } diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 32159fb0..a72e67fd 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.Metrics; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Engine { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 58f2be1a..c287309d 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,13 +1,26 @@ using System.Collections.Generic; using Tensorflow.Keras.Metrics; +using Tensorflow.ModelSaving; namespace Tensorflow.Keras.Engine { public partial class Model { - public void save(string path) + ModelSaver saver = new ModelSaver(); + + /// + /// Saves the model to Tensorflow SavedModel or a single HDF5 file. + /// + /// + /// + /// + public void save(string filepath, + bool overwrite = true, + bool include_optimizer = true, + string save_format = "tf", + SaveOptions options = null) { - + saver.save(this, filepath); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs new file mode 100644 index 00000000..05d544f8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Engine +{ + public partial class Node + { + /// + /// Serializes `Node` for Functional API's `get_config`. + /// + /// + public NodeConfig serialize(Func make_node_key, Dictionary node_conversion_map) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Node.cs b/src/TensorFlowNET.Keras/Engine/Node.cs index 1601a6a6..d78e5533 100644 --- a/src/TensorFlowNET.Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Keras/Engine/Node.cs @@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Engine public TensorShape[] output_shapes; public List KerasInputs { get; set; } = new List(); public ILayer Layer { get; set; } - public bool IsInput => args.InputTensors == null; + public bool is_input => args.InputTensors == null; public int[] FlatInputIds { get; set; } public int[] FlatOutputIds { get; set; } bool _single_positional_tensor_passed => KerasInputs.Count() == 1; diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 5f262083..50974cf7 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -17,7 +17,7 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Engine { diff --git a/src/TensorFlowNET.Keras/KerasApi.cs b/src/TensorFlowNET.Keras/KerasApi.cs index 97d28943..a22c0399 100644 --- a/src/TensorFlowNET.Keras/KerasApi.cs +++ b/src/TensorFlowNET.Keras/KerasApi.cs @@ -1,84 +1,12 @@ -using System.Collections.Generic; -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Datasets; -using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers; -using Tensorflow.Keras.Losses; -using Tensorflow.Keras.Metrics; -using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras; -namespace Tensorflow.Keras +namespace Tensorflow { - public class KerasApi + public static class KerasApi { - public KerasDataset datasets { get; } = new KerasDataset(); - public Initializers initializers { get; } = new Initializers(); - public Regularizers regularizers { get; } = new Regularizers(); - public LayersApi layers { get; } = new LayersApi(); - public LossesApi losses { get; } = new LossesApi(); - public Activations activations { get; } = new Activations(); - public Preprocessing preprocessing { get; } = new Preprocessing(); - public BackendImpl backend { get; } = new BackendImpl(); - public OptimizerApi optimizers { get; } = new OptimizerApi(); - public MetricsApi metrics { get; } = new MetricsApi(); + public static KerasInterface Keras(this tensorflow tf) + => new KerasInterface(); - public Sequential Sequential(List layers = null, - string name = null) - => new Sequential(new SequentialArgs - { - Layers = layers, - Name = name - }); - - /// - /// `Model` groups layers into an object with training and inference features. - /// - /// - /// - /// - public Functional Model(Tensors inputs, Tensors outputs, string name = null) - => new Functional(inputs, outputs, name: name); - - /// - /// Instantiate a Keras tensor. - /// - /// - /// - /// - /// - /// - /// A boolean specifying whether the placeholder to be created is sparse. - /// - /// - /// A boolean specifying whether the placeholder to be created is ragged. - /// - /// - /// Optional existing tensor to wrap into the `Input` layer. - /// If set, the layer will not create a placeholder tensor. - /// - /// - public Tensor Input(TensorShape shape = null, - int batch_size = -1, - TF_DataType dtype = TF_DataType.DtInvalid, - string name = null, - bool sparse = false, - bool ragged = false, - Tensor tensor = null) - { - var args = new InputLayerArgs - { - Name = name, - InputShape = shape, - BatchSize = batch_size, - DType = dtype, - Sparse = sparse, - Ragged = ragged, - InputTensor = tensor - }; - - var layer = new InputLayer(args); - - return layer.InboundNodes[0].Outputs; - } + public static KerasInterface keras { get; } = new KerasInterface(); } } diff --git a/src/TensorFlowNET.Keras/KerasExtension.cs b/src/TensorFlowNET.Keras/KerasExtension.cs deleted file mode 100644 index 75420201..00000000 --- a/src/TensorFlowNET.Keras/KerasExtension.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Tensorflow.Keras; - -namespace Tensorflow -{ - public static class KerasExt - { - public static KerasApi Keras(this tensorflow tf) - => new KerasApi(); - - public static KerasApi keras { get; } = new KerasApi(); - } -} diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs new file mode 100644 index 00000000..5455148f --- /dev/null +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -0,0 +1,84 @@ +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Datasets; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Optimizers; + +namespace Tensorflow.Keras +{ + public class KerasInterface + { + public KerasDataset datasets { get; } = new KerasDataset(); + public Initializers initializers { get; } = new Initializers(); + public Regularizers regularizers { get; } = new Regularizers(); + public LayersApi layers { get; } = new LayersApi(); + public LossesApi losses { get; } = new LossesApi(); + public Activations activations { get; } = new Activations(); + public Preprocessing preprocessing { get; } = new Preprocessing(); + public BackendImpl backend { get; } = new BackendImpl(); + public OptimizerApi optimizers { get; } = new OptimizerApi(); + public MetricsApi metrics { get; } = new MetricsApi(); + + public Sequential Sequential(List layers = null, + string name = null) + => new Sequential(new SequentialArgs + { + Layers = layers, + Name = name + }); + + /// + /// `Model` groups layers into an object with training and inference features. + /// + /// + /// + /// + public Functional Model(Tensors inputs, Tensors outputs, string name = null) + => new Functional(inputs, outputs, name: name); + + /// + /// Instantiate a Keras tensor. + /// + /// + /// + /// + /// + /// + /// A boolean specifying whether the placeholder to be created is sparse. + /// + /// + /// A boolean specifying whether the placeholder to be created is ragged. + /// + /// + /// Optional existing tensor to wrap into the `Input` layer. + /// If set, the layer will not create a placeholder tensor. + /// + /// + public Tensor Input(TensorShape shape = null, + int batch_size = -1, + TF_DataType dtype = TF_DataType.DtInvalid, + string name = null, + bool sparse = false, + bool ragged = false, + Tensor tensor = null) + { + var args = new InputLayerArgs + { + Name = name, + InputShape = shape, + BatchSize = batch_size, + DType = dtype, + Sparse = sparse, + Ragged = ragged, + InputTensor = tensor + }; + + var layer = new InputLayer(args); + + return layer.InboundNodes[0].Outputs; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/InputLayer.cs index b58ce777..32b566ea 100644 --- a/src/TensorFlowNET.Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/InputLayer.cs @@ -19,7 +19,7 @@ using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Layers { @@ -50,6 +50,7 @@ namespace Tensorflow.Keras.Layers { var prefix = "input"; name = prefix + '_' + keras.backend.get_uid(prefix); + args.Name = name; } if (args.DType == TF_DataType.DtInvalid) @@ -99,5 +100,8 @@ namespace Tensorflow.Keras.Layers tf.Context.restore_mode(); } + + public override LayerArgs get_config() + => args; } } diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index d99f4041..7711dd16 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -2,7 +2,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Layers { diff --git a/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs index 6e479c70..7f6ff3e7 100644 --- a/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs +++ b/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs @@ -2,7 +2,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Layers { diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index 73563556..c9af1915 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -1,4 +1,4 @@ -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras { diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 35ae759f..9c03b503 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -35,4 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac + + + + diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index f0d314d2..45ec506f 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -21,7 +21,7 @@ using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Utils { diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 1de763a1..c2839cdc 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -16,11 +16,22 @@ using System; using System.Linq; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Utils { public class generic_utils { + public static LayerConfig serialize_keras_object(ILayer instance) + { + var config = instance.get_config(); + return new LayerConfig + { + Config = config, + ClassName = instance.GetType().Name + }; + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index c539919c..ebdcf757 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.1 + net5.0 AnyCPU;x64 diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index ad91de80..477c2ae9 100644 --- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -1,6 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace TensorFlowNET.UnitTest.Keras { diff --git a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs index 9eb2f39a..886e30fc 100644 --- a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs @@ -1,6 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow.Keras.Engine; -using static Tensorflow.KerasExt; +using static Tensorflow.KerasApi; namespace TensorFlowNET.UnitTest.Keras { @@ -10,21 +10,20 @@ namespace TensorFlowNET.UnitTest.Keras [TestClass] public class ModelSaveTest : EagerModeTestBase { - [TestMethod] - public void SaveAndLoadTest() + [TestMethod, Ignore] + public void GetAndFromConfig() { - var model = GetModel(); + var model = GetFunctionalModel(); + var config = model.get_config(); } - Model GetModel() + Functional GetFunctionalModel() { // Create a simple model. var inputs = keras.Input(shape: 32); var dense_layer = keras.layers.Dense(1); var outputs = dense_layer.Apply(inputs); - var model = keras.Model(inputs, outputs); - model.compile("adam", "mean_squared_error"); - return model; + return keras.Model(inputs, outputs); } } }