diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs
new file mode 100644
index 00000000..9287284f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs
@@ -0,0 +1,64 @@
+namespace Tensorflow.Keras.Engine;
+
+///
+/// A representation of a Keras in/output during Functional API construction.
+///
+public class KerasTensor
+{
+ private Tensors _original_tensors;
+ public Tensors original_tensors
+ {
+ get => _original_tensors;
+ set => _original_tensors = value;
+ }
+
+ private Shape _inferred_value;
+ public Shape inferred_value => _inferred_value;
+
+ private string _name;
+ private TensorSpec _type_spec;
+ public Shape shape => _type_spec.shape;
+ public TF_DataType dtype => _type_spec.dtype;
+
+ public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null)
+ {
+ _type_spec = type_spec;
+ _inferred_value = inferred_value;
+ _name = name;
+ }
+
+ public static KerasTensor from_tensor(Tensor tensor)
+ {
+ var type_spec = tensor.ToTensorSpec();
+ var kt = new KerasTensor(type_spec, name: tensor.name);
+ kt.original_tensors = tensor;
+ return kt;
+ }
+
+ public override string ToString()
+ => _original_tensors.Length switch
+ {
+ > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]",
+ 1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}",
+ _ => _original_tensors.ToString(),
+ };
+
+ private string GetInferredValueString()
+ => _inferred_value == null ? "" : "";
+
+ public static implicit operator Tensors(KerasTensor kt)
+ => kt._original_tensors;
+
+ public static implicit operator Tensor(KerasTensor kt)
+ {
+ Tensor tensor = kt._original_tensors;
+ tensor.IsFromKerasTensor = true;
+ return tensor;
+ }
+
+ public static implicit operator KerasTensor(Tensor tensor)
+ => from_tensor(tensor);
+
+ public static implicit operator KerasTensor(Tensors tensors)
+ => from_tensor(tensors.First());
+}
diff --git a/src/TensorFlowNET.Core/Tensors/KerasTensor.cs b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs
deleted file mode 100644
index 3204b4ac..00000000
--- a/src/TensorFlowNET.Core/Tensors/KerasTensor.cs
+++ /dev/null
@@ -1,53 +0,0 @@
-namespace Tensorflow.Keras.Engine;
-
-///
-/// A representation of a Keras in/output during Functional API construction.
-///
-public class KerasTensor
-{
- private Tensors _inferred_value;
- public Tensors inferred_value
- {
- get => _inferred_value;
- set => _inferred_value = value;
- }
-
- private string _name;
- private TensorSpec _type_spec;
- public Shape shape => _type_spec.shape;
- public TF_DataType dtype => _type_spec.dtype;
-
- public KerasTensor(TensorSpec type_spec, string name = null)
- {
- _type_spec = type_spec;
- _name = name;
- }
-
- public static KerasTensor from_tensor(Tensor tensor)
- {
- var type_spec = tensor.ToTensorSpec();
- var kt = new KerasTensor(type_spec, name: tensor.name);
- kt.inferred_value = tensor;
- return kt;
- }
-
- public override string ToString()
- => _inferred_value.Length switch
- {
- > 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"")) + "]",
- 1 => $"",
- _ => _inferred_value.ToString(),
- };
-
- public static implicit operator Tensors(KerasTensor kt)
- => kt._inferred_value;
-
- public static implicit operator Tensor(KerasTensor kt)
- => kt._inferred_value;
-
- public static implicit operator KerasTensor(Tensor tensor)
- => from_tensor(tensor);
-
- public static implicit operator KerasTensor(Tensors tensors)
- => from_tensor(tensors.First());
-}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
index 18bdc1aa..fdd62aee 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
@@ -14,19 +14,10 @@
limitations under the License.
******************************************************************************/
-using Tensorflow.NumPy;
-using System;
-using System.Diagnostics.CodeAnalysis;
-using System.Text;
-using Tensorflow.Framework.Models;
-using static Tensorflow.Binding;
+namespace Tensorflow;
-namespace Tensorflow
+public partial class Tensor
{
- [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
- public partial class Tensor
- {
- public TensorSpec ToTensorSpec()
- => new TensorSpec(shape, dtype, name);
- }
+ public TensorSpec ToTensorSpec()
+ => new TensorSpec(shape, dtype, name);
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs
new file mode 100644
index 00000000..ca946ca4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Keras.cs
@@ -0,0 +1,27 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+namespace Tensorflow;
+
+public partial class Tensor
+{
+ public bool IsFromKerasTensor { get; set; }
+
+ ///
+ /// Keras History: (Layer, (node_index, tensor_index))
+ ///
+ public KerasHistory KerasHistory { get; set; }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index c0e5d435..65e1c857 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -146,11 +146,6 @@ namespace Tensorflow
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();
}
- ///
- /// Keras History: (Layer, (node_index, tensor_index))
- ///
- public KerasHistory KerasHistory { get; set; }
-
///
/// Updates the shape of this tensor.
///
diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs
index 44dca58d..2605c41e 100644
--- a/src/TensorFlowNET.Keras/Models/ModelsApi.cs
+++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs
@@ -1,22 +1,15 @@
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Text;
-using Tensorflow.Keras.Engine;
-using Tensorflow.Keras.Saving;
+using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
-using ThirdParty.Tensorflow.Python.Keras.Protobuf;
-namespace Tensorflow.Keras.Models
+namespace Tensorflow.Keras.Models;
+
+public class ModelsApi: IModelsApi
{
- public class ModelsApi: IModelsApi
- {
- public Functional from_config(FunctionalConfig config)
- => Functional.from_config(config);
+ public Functional from_config(FunctionalConfig config)
+ => Functional.from_config(config);
- public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
- {
- return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
- }
+ public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
+ {
+ return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
}
}
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
index aa763fc2..091dbb81 100644
--- a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
@@ -1,97 +1,89 @@
-using Google.Protobuf;
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Text;
-using Tensorflow.Keras.Engine;
+using System.IO;
using Tensorflow.Train;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
-using static Tensorflow.Binding;
-using static Tensorflow.KerasApi;
-namespace Tensorflow.Keras.Saving.SavedModel
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public class KerasLoadModelUtils
{
- public class KerasLoadModelUtils
+ ///
+ /// Corresponding to keras/saving/save.py/load_model
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Trackable load_model(string filepath, IDictionary? custom_objects = null,
+ bool compile = true, LoadOptions? options = null)
{
- ///
- /// Corresponding to keras/saving/save.py/load_model
- ///
- ///
- ///
- ///
- ///
- ///
- public static Trackable load_model(string filepath, IDictionary? custom_objects = null,
- bool compile = true, LoadOptions? options = null)
+ using var savingScope = SharedObjectSavingScope.Enter();
+
+ using var ctx = LoadContext.load_context(options);
+
+ if (!File.Exists(filepath) && !Directory.Exists(filepath))
{
- using (SharedObjectSavingScope.Enter())
- {
- using (LoadContext.load_context(options))
- {
- if (!File.Exists(filepath) && !Directory.Exists(filepath))
- {
- throw new IOException($"No file or directory found at {filepath}.");
- }
- if (Directory.Exists(filepath))
- {
- return load(filepath, compile, options);
- }
- else
- {
- throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
- }
- }
- }
+ throw new IOException($"No file or directory found at {filepath}.");
}
- private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
+ if (Directory.Exists(filepath))
+ {
+ return load(filepath, compile, options);
+ }
+ else
{
- SavedMetadata metadata = new SavedMetadata();
- var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
- var object_graph_def = meta_graph_def.ObjectGraphDef;
- string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
- if (File.Exists(path_to_metadata_pb))
- {
- metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read));
- }
- else
- {
- throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
- " use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
- }
+ throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
+ }
+ }
- if (metadata.Nodes is null || metadata.Nodes.Count == 0)
- {
- return Loader.load(path, options: options) as Model;
- }
+ private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
+ {
+ SavedMetadata metadata;
+ var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
+ var object_graph_def = meta_graph_def.ObjectGraphDef;
+ string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
+ if (File.Exists(path_to_metadata_pb))
+ {
+ using var stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read);
+ metadata = SavedMetadata.Parser.ParseFrom(stream);
+ }
+ else
+ {
+ throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
+ " use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
+ }
- var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
- keras_loader.load_layers(compile: compile);
+ if (metadata.Nodes is null || metadata.Nodes.Count == 0)
+ {
+ return Loader.load(path, options: options) as Model;
+ }
- Dictionary)> nodes_to_load = new();
- nodes_to_load["root"] = (null, null);
- foreach(var item in keras_loader.LoadedNodes)
- {
- nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
- }
- var loaded = Loader.load_partial(path, nodes_to_load, options);
+ var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
+ keras_loader.load_layers(compile: compile);
- keras_loader.finalize_objects();
- keras_loader.del_tracking();
+ Dictionary)> nodes_to_load = new();
+ nodes_to_load["root"] = (null, null);
+ foreach(var item in keras_loader.LoadedNodes)
+ {
+ nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
+ }
+ var loaded = Loader.load_partial(path, nodes_to_load, options);
- var model = loaded["root"];
+ keras_loader.finalize_objects();
+ keras_loader.del_tracking();
- if(model is Model && compile)
- {
- // TODO(Rinne): implement it.
- }
+ var model = loaded["root"];
- if (!tf.Context.executing_eagerly())
- {
- // TODO(Rinne): implement it.
- }
+ if (model is Model && compile)
+ {
+ // TODO(Rinne): implement it.
+ }
- return model;
+ if (!tf.Context.executing_eagerly())
+ {
+ // TODO(Rinne): implement it.
}
+
+ return model;
}
}