diff --git a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj
index 6cc631f4..89884977 100644
--- a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj
+++ b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp3.1
+ net5.0
Tensorflow
Tensorflow
diff --git a/src/TensorFlowNET.Core/Keras/Engine/INode.cs b/src/TensorFlowNET.Core/Keras/Engine/INode.cs
index 1305a772..dde0f8ea 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/INode.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/INode.cs
@@ -1,4 +1,6 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.Engine
{
@@ -10,5 +12,7 @@ namespace Tensorflow.Keras.Engine
List KerasInputs { get; set; }
INode[] ParentNodes { get; }
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
+ bool is_input { get; }
+ NodeConfig serialize(Func make_node_key, Dictionary node_conversion_map);
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
index c096458f..d6bbf11a 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
@@ -1,4 +1,5 @@
using System.Collections.Generic;
+using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras
@@ -14,5 +15,6 @@ namespace Tensorflow.Keras
List trainable_variables { get; }
TensorShape output_shape { get; }
int count_params();
+ LayerArgs get_config();
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
new file mode 100644
index 00000000..950b3132
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Saving
+{
+ public class LayerConfig
+ {
+ public string Name { get; set; }
+ public string ClassName { get; set; }
+ public LayerArgs Config { get; set; }
+ public List InboundNodes { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
new file mode 100644
index 00000000..fa965aa4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Saving
+{
+ public class ModelConfig
+ {
+ public string Name { get; set; }
+ public List Layers { get; set; }
+ public List InputLayers { get; set; }
+ public List OutputLayers { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
new file mode 100644
index 00000000..732d9d4d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Saving
+{
+ public class NodeConfig
+ {
+ public string Name { get; set; }
+ public int NodeIndex { get; set; }
+ public int TensorIndex { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
new file mode 100644
index 00000000..4437ba0a
--- /dev/null
+++ b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
@@ -0,0 +1,25 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Train;
+
+namespace Tensorflow.ModelSaving
+{
+ public class ModelSaver
+ {
+ public void save(Trackable obj, string export_dir, SaveOptions options = null)
+ {
+ var saved_model = new SavedModel();
+ var meta_graph_def = new MetaGraphDef();
+ saved_model.MetaGraphs.Add(meta_graph_def);
+ _build_meta_graph(obj, export_dir, options, meta_graph_def);
+ }
+
+ void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options,
+ MetaGraphDef meta_graph_def = null)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
new file mode 100644
index 00000000..e25537d8
--- /dev/null
+++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.ModelSaving
+{
+ ///
+ /// Options for saving to SavedModel.
+ ///
+ public class SaveOptions
+ {
+ bool save_debug_info;
+ public SaveOptions(bool save_debug_info = false)
+ {
+ this.save_debug_info = save_debug_info;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 3c717825..b2e1566b 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras;
+using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using Tensorflow.Util;
@@ -132,5 +133,10 @@ namespace Tensorflow
{
throw new NotImplementedException();
}
+
+ public LayerArgs get_config()
+ {
+ throw new NotImplementedException();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Protobuf/Gen.bat b/src/TensorFlowNET.Core/Protobuf/Gen.bat
index 745235af..c6256737 100644
--- a/src/TensorFlowNET.Core/Protobuf/Gen.bat
+++ b/src/TensorFlowNET.Core/Protobuf/Gen.bat
@@ -27,6 +27,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.pro
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto
ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto
diff --git a/src/TensorFlowNET.Core/Protobuf/SavedModel.cs b/src/TensorFlowNET.Core/Protobuf/SavedModel.cs
new file mode 100644
index 00000000..e7b9259a
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/SavedModel.cs
@@ -0,0 +1,210 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: tensorflow/core/protobuf/saved_model.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from tensorflow/core/protobuf/saved_model.proto
+ public static partial class SavedModelReflection {
+
+ #region Descriptor
+ /// File descriptor for tensorflow/core/protobuf/saved_model.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static SavedModelReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Cip0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfbW9kZWwucHJvdG8S",
+ "CnRlbnNvcmZsb3caKXRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi9tZXRhX2dy",
+ "YXBoLnByb3RvIl8KClNhdmVkTW9kZWwSIgoac2F2ZWRfbW9kZWxfc2NoZW1h",
+ "X3ZlcnNpb24YASABKAMSLQoLbWV0YV9ncmFwaHMYAiADKAsyGC50ZW5zb3Jm",
+ "bG93Lk1ldGFHcmFwaERlZkJ7ChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC",
+ "EFNhdmVkTW9kZWxQcm90b3NQAVpIZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl",
+ "bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2NvcmVfcHJvdG9zX2dvX3By",
+ "b3Rv+AEBYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::Tensorflow.MetaGraphReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedModel), global::Tensorflow.SavedModel.Parser, new[]{ "SavedModelSchemaVersion", "MetaGraphs" }, null, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// SavedModel is the high level serialization format for TensorFlow Models.
+ /// See [todo: doc links, similar to session_bundle] for more information.
+ ///
+ public sealed partial class SavedModel : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SavedModel());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.SavedModelReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SavedModel() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SavedModel(SavedModel other) : this() {
+ savedModelSchemaVersion_ = other.savedModelSchemaVersion_;
+ metaGraphs_ = other.metaGraphs_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public SavedModel Clone() {
+ return new SavedModel(this);
+ }
+
+ /// Field number for the "saved_model_schema_version" field.
+ public const int SavedModelSchemaVersionFieldNumber = 1;
+ private long savedModelSchemaVersion_;
+ ///
+ /// The schema version of the SavedModel instance. Used for versioning when
+ /// making future changes to the specification/implementation. Initial value
+ /// at release will be 1.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public long SavedModelSchemaVersion {
+ get { return savedModelSchemaVersion_; }
+ set {
+ savedModelSchemaVersion_ = value;
+ }
+ }
+
+ /// Field number for the "meta_graphs" field.
+ public const int MetaGraphsFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_metaGraphs_codec
+ = pb::FieldCodec.ForMessage(18, global::Tensorflow.MetaGraphDef.Parser);
+ private readonly pbc::RepeatedField metaGraphs_ = new pbc::RepeatedField();
+ ///
+ /// One or more MetaGraphs.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField MetaGraphs {
+ get { return metaGraphs_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as SavedModel);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(SavedModel other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (SavedModelSchemaVersion != other.SavedModelSchemaVersion) return false;
+ if(!metaGraphs_.Equals(other.metaGraphs_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (SavedModelSchemaVersion != 0L) hash ^= SavedModelSchemaVersion.GetHashCode();
+ hash ^= metaGraphs_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (SavedModelSchemaVersion != 0L) {
+ output.WriteRawTag(8);
+ output.WriteInt64(SavedModelSchemaVersion);
+ }
+ metaGraphs_.WriteTo(output, _repeated_metaGraphs_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (SavedModelSchemaVersion != 0L) {
+ size += 1 + pb::CodedOutputStream.ComputeInt64Size(SavedModelSchemaVersion);
+ }
+ size += metaGraphs_.CalculateSize(_repeated_metaGraphs_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(SavedModel other) {
+ if (other == null) {
+ return;
+ }
+ if (other.SavedModelSchemaVersion != 0L) {
+ SavedModelSchemaVersion = other.SavedModelSchemaVersion;
+ }
+ metaGraphs_.Add(other.metaGraphs_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ SavedModelSchemaVersion = input.ReadInt64();
+ break;
+ }
+ case 18: {
+ metaGraphs_.AddEntriesFrom(input, _repeated_metaGraphs_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
new file mode 100644
index 00000000..96ea11b6
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
@@ -0,0 +1,76 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Saving;
+using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Engine
+{
+ public partial class Functional
+ {
+ public ModelConfig get_config()
+ {
+ return get_network_config();
+ }
+
+ ///
+ /// Builds the config, which consists of the node graph and serialized layers.
+ ///
+ ModelConfig get_network_config()
+ {
+ var config = new ModelConfig
+ {
+ Name = name
+ };
+
+ var node_conversion_map = new Dictionary();
+ foreach (var layer in _layers)
+ {
+ var kept_nodes = _should_skip_first_node(layer) ? 1 : 0;
+ foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
+ {
+ var node_key = _make_node_key(layer.Name, original_node_index);
+ if (NetworkNodes.Contains(node_key))
+ {
+ node_conversion_map[node_key] = kept_nodes;
+ kept_nodes += 1;
+ }
+ }
+ }
+
+ var layer_configs = new List();
+ foreach (var layer in _layers)
+ {
+ var filtered_inbound_nodes = new List();
+ foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
+ {
+ var node_key = _make_node_key(layer.Name, original_node_index);
+ if (NetworkNodes.Contains(node_key) && !node.is_input)
+ {
+ var node_data = node.serialize(_make_node_key, node_conversion_map);
+ throw new NotImplementedException("");
+ }
+ }
+
+ var layer_config = generic_utils.serialize_keras_object(layer);
+ layer_config.Name = layer.Name;
+ layer_config.InboundNodes = filtered_inbound_nodes;
+ layer_configs.Add(layer_config);
+ }
+ config.Layers = layer_configs;
+
+ return config;
+ }
+
+ bool _should_skip_first_node(ILayer layer)
+ {
+ return layer is Functional && layer.Layers[0] is InputLayer;
+ }
+
+ string _make_node_key(string layer_name, int node_index)
+ => $"{layer_name}_ib-{node_index}";
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs
index 68027b07..d55b6781 100644
--- a/src/TensorFlowNET.Keras/Engine/Functional.cs
+++ b/src/TensorFlowNET.Keras/Engine/Functional.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine
///
/// A `Functional` model is a `Model` defined as a directed graph of layers.
///
- public class Functional : Model
+ public partial class Functional : Model
{
TensorShape _build_input_shape;
bool _compute_output_and_mask_jointly;
@@ -283,7 +283,7 @@ namespace Tensorflow.Keras.Engine
// Propagate to all previous tensors connected to this node.
nodes_in_progress.Add(node);
- if (!node.IsInput)
+ if (!node.is_input)
{
foreach (var k_tensor in node.KerasInputs)
{
@@ -307,7 +307,7 @@ namespace Tensorflow.Keras.Engine
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
{
- if (mask != null)
+ if (mask == null)
{
Tensor[] masks = new Tensor[inputs.Count()];
foreach (var (i, input_t) in enumerate(inputs))
@@ -330,7 +330,7 @@ namespace Tensorflow.Keras.Engine
foreach (Node node in nodes)
{
// Input tensors already exist.
- if (node.IsInput)
+ if (node.is_input)
continue;
var layer_inputs = node.MapArguments(tensor_dict);
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 0c892535..93c9f91f 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;
@@ -241,5 +242,8 @@ namespace Tensorflow.Keras.Engine
return weights;
}
}
+
+ public virtual LayerArgs get_config()
+ => throw new NotImplementedException("");
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
index 32159fb0..a72e67fd 100644
--- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
@@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Metrics;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Engine
{
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs
index 58f2be1a..c287309d 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs
@@ -1,13 +1,26 @@
using System.Collections.Generic;
using Tensorflow.Keras.Metrics;
+using Tensorflow.ModelSaving;
namespace Tensorflow.Keras.Engine
{
public partial class Model
{
- public void save(string path)
+ ModelSaver saver = new ModelSaver();
+
+ ///
+ /// Saves the model to Tensorflow SavedModel or a single HDF5 file.
+ ///
+ ///
+ ///
+ ///
+ public void save(string filepath,
+ bool overwrite = true,
+ bool include_optimizer = true,
+ string save_format = "tf",
+ SaveOptions options = null)
{
-
+ saver.save(this, filepath);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs
new file mode 100644
index 00000000..05d544f8
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Engine
+{
+ public partial class Node
+ {
+ ///
+ /// Serializes `Node` for Functional API's `get_config`.
+ ///
+ ///
+ public NodeConfig serialize(Func make_node_key, Dictionary node_conversion_map)
+ {
+ throw new NotImplementedException("");
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Engine/Node.cs b/src/TensorFlowNET.Keras/Engine/Node.cs
index 1601a6a6..d78e5533 100644
--- a/src/TensorFlowNET.Keras/Engine/Node.cs
+++ b/src/TensorFlowNET.Keras/Engine/Node.cs
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Engine
public TensorShape[] output_shapes;
public List KerasInputs { get; set; } = new List();
public ILayer Layer { get; set; }
- public bool IsInput => args.InputTensors == null;
+ public bool is_input => args.InputTensors == null;
public int[] FlatInputIds { get; set; }
public int[] FlatOutputIds { get; set; }
bool _single_positional_tensor_passed => KerasInputs.Count() == 1;
diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs
index 5f262083..50974cf7 100644
--- a/src/TensorFlowNET.Keras/Engine/Sequential.cs
+++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs
@@ -17,7 +17,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Engine
{
diff --git a/src/TensorFlowNET.Keras/KerasApi.cs b/src/TensorFlowNET.Keras/KerasApi.cs
index 97d28943..a22c0399 100644
--- a/src/TensorFlowNET.Keras/KerasApi.cs
+++ b/src/TensorFlowNET.Keras/KerasApi.cs
@@ -1,84 +1,12 @@
-using System.Collections.Generic;
-using Tensorflow.Keras.ArgsDefinition;
-using Tensorflow.Keras.Datasets;
-using Tensorflow.Keras.Engine;
-using Tensorflow.Keras.Layers;
-using Tensorflow.Keras.Losses;
-using Tensorflow.Keras.Metrics;
-using Tensorflow.Keras.Optimizers;
+using Tensorflow.Keras;
-namespace Tensorflow.Keras
+namespace Tensorflow
{
- public class KerasApi
+ public static class KerasApi
{
- public KerasDataset datasets { get; } = new KerasDataset();
- public Initializers initializers { get; } = new Initializers();
- public Regularizers regularizers { get; } = new Regularizers();
- public LayersApi layers { get; } = new LayersApi();
- public LossesApi losses { get; } = new LossesApi();
- public Activations activations { get; } = new Activations();
- public Preprocessing preprocessing { get; } = new Preprocessing();
- public BackendImpl backend { get; } = new BackendImpl();
- public OptimizerApi optimizers { get; } = new OptimizerApi();
- public MetricsApi metrics { get; } = new MetricsApi();
+ public static KerasInterface Keras(this tensorflow tf)
+ => new KerasInterface();
- public Sequential Sequential(List layers = null,
- string name = null)
- => new Sequential(new SequentialArgs
- {
- Layers = layers,
- Name = name
- });
-
- ///
- /// `Model` groups layers into an object with training and inference features.
- ///
- ///
- ///
- ///
- public Functional Model(Tensors inputs, Tensors outputs, string name = null)
- => new Functional(inputs, outputs, name: name);
-
- ///
- /// Instantiate a Keras tensor.
- ///
- ///
- ///
- ///
- ///
- ///
- /// A boolean specifying whether the placeholder to be created is sparse.
- ///
- ///
- /// A boolean specifying whether the placeholder to be created is ragged.
- ///
- ///
- /// Optional existing tensor to wrap into the `Input` layer.
- /// If set, the layer will not create a placeholder tensor.
- ///
- ///
- public Tensor Input(TensorShape shape = null,
- int batch_size = -1,
- TF_DataType dtype = TF_DataType.DtInvalid,
- string name = null,
- bool sparse = false,
- bool ragged = false,
- Tensor tensor = null)
- {
- var args = new InputLayerArgs
- {
- Name = name,
- InputShape = shape,
- BatchSize = batch_size,
- DType = dtype,
- Sparse = sparse,
- Ragged = ragged,
- InputTensor = tensor
- };
-
- var layer = new InputLayer(args);
-
- return layer.InboundNodes[0].Outputs;
- }
+ public static KerasInterface keras { get; } = new KerasInterface();
}
}
diff --git a/src/TensorFlowNET.Keras/KerasExtension.cs b/src/TensorFlowNET.Keras/KerasExtension.cs
deleted file mode 100644
index 75420201..00000000
--- a/src/TensorFlowNET.Keras/KerasExtension.cs
+++ /dev/null
@@ -1,12 +0,0 @@
-using Tensorflow.Keras;
-
-namespace Tensorflow
-{
- public static class KerasExt
- {
- public static KerasApi Keras(this tensorflow tf)
- => new KerasApi();
-
- public static KerasApi keras { get; } = new KerasApi();
- }
-}
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
new file mode 100644
index 00000000..5455148f
--- /dev/null
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -0,0 +1,84 @@
+using System.Collections.Generic;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Datasets;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
+using Tensorflow.Keras.Optimizers;
+
+namespace Tensorflow.Keras
+{
+ public class KerasInterface
+ {
+ public KerasDataset datasets { get; } = new KerasDataset();
+ public Initializers initializers { get; } = new Initializers();
+ public Regularizers regularizers { get; } = new Regularizers();
+ public LayersApi layers { get; } = new LayersApi();
+ public LossesApi losses { get; } = new LossesApi();
+ public Activations activations { get; } = new Activations();
+ public Preprocessing preprocessing { get; } = new Preprocessing();
+ public BackendImpl backend { get; } = new BackendImpl();
+ public OptimizerApi optimizers { get; } = new OptimizerApi();
+ public MetricsApi metrics { get; } = new MetricsApi();
+
+ public Sequential Sequential(List layers = null,
+ string name = null)
+ => new Sequential(new SequentialArgs
+ {
+ Layers = layers,
+ Name = name
+ });
+
+ ///
+ /// `Model` groups layers into an object with training and inference features.
+ ///
+ ///
+ ///
+ ///
+ public Functional Model(Tensors inputs, Tensors outputs, string name = null)
+ => new Functional(inputs, outputs, name: name);
+
+ ///
+ /// Instantiate a Keras tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// A boolean specifying whether the placeholder to be created is sparse.
+ ///
+ ///
+ /// A boolean specifying whether the placeholder to be created is ragged.
+ ///
+ ///
+ /// Optional existing tensor to wrap into the `Input` layer.
+ /// If set, the layer will not create a placeholder tensor.
+ ///
+ ///
+ public Tensor Input(TensorShape shape = null,
+ int batch_size = -1,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ string name = null,
+ bool sparse = false,
+ bool ragged = false,
+ Tensor tensor = null)
+ {
+ var args = new InputLayerArgs
+ {
+ Name = name,
+ InputShape = shape,
+ BatchSize = batch_size,
+ DType = dtype,
+ Sparse = sparse,
+ Ragged = ragged,
+ InputTensor = tensor
+ };
+
+ var layer = new InputLayer(args);
+
+ return layer.InboundNodes[0].Outputs;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/InputLayer.cs
index b58ce777..32b566ea 100644
--- a/src/TensorFlowNET.Keras/Layers/InputLayer.cs
+++ b/src/TensorFlowNET.Keras/Layers/InputLayer.cs
@@ -19,7 +19,7 @@ using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers
{
@@ -50,6 +50,7 @@ namespace Tensorflow.Keras.Layers
{
var prefix = "input";
name = prefix + '_' + keras.backend.get_uid(prefix);
+ args.Name = name;
}
if (args.DType == TF_DataType.DtInvalid)
@@ -99,5 +100,8 @@ namespace Tensorflow.Keras.Layers
tf.Context.restore_mode();
}
+
+ public override LayerArgs get_config()
+ => args;
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index d99f4041..7711dd16 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -2,7 +2,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers
{
diff --git a/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs
index 6e479c70..7f6ff3e7 100644
--- a/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs
+++ b/src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs
@@ -2,7 +2,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers
{
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
index 73563556..c9af1915 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
@@ -1,4 +1,4 @@
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras
{
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 35ae759f..9c03b503 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -35,4 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
+
+
+
+
diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
index f0d314d2..45ec506f 100644
--- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
@@ -21,7 +21,7 @@ using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Utils
{
diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs
index 1de763a1..c2839cdc 100644
--- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs
@@ -16,11 +16,22 @@
using System;
using System.Linq;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.Utils
{
public class generic_utils
{
+ public static LayerConfig serialize_keras_object(ILayer instance)
+ {
+ var config = instance.get_config();
+ return new LayerConfig
+ {
+ Config = config,
+ ClassName = instance.GetType().Name
+ };
+ }
+
public static string to_snake_case(string name)
{
return string.Concat(name.Select((x, i) =>
diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
index c539919c..ebdcf757 100644
--- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
+++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp3.1
+ net5.0
AnyCPU;x64
diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
index ad91de80..477c2ae9 100644
--- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
+++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace TensorFlowNET.UnitTest.Keras
{
diff --git a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
index 9eb2f39a..886e30fc 100644
--- a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
+++ b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Engine;
-using static Tensorflow.KerasExt;
+using static Tensorflow.KerasApi;
namespace TensorFlowNET.UnitTest.Keras
{
@@ -10,21 +10,20 @@ namespace TensorFlowNET.UnitTest.Keras
[TestClass]
public class ModelSaveTest : EagerModeTestBase
{
- [TestMethod]
- public void SaveAndLoadTest()
+ [TestMethod, Ignore]
+ public void GetAndFromConfig()
{
- var model = GetModel();
+ var model = GetFunctionalModel();
+ var config = model.get_config();
}
- Model GetModel()
+ Functional GetFunctionalModel()
{
// Create a simple model.
var inputs = keras.Input(shape: 32);
var dense_layer = keras.layers.Dense(1);
var outputs = dense_layer.Apply(inputs);
- var model = keras.Model(inputs, outputs);
- model.compile("adam", "mean_squared_error");
- return model;
+ return keras.Model(inputs, outputs);
}
}
}