@@ -0,0 +1,64 @@ | |||||
namespace Tensorflow.Keras.Engine; | |||||
/// <summary> | |||||
/// A representation of a Keras in/output during Functional API construction. | |||||
/// </summary> | |||||
public class KerasTensor | |||||
{ | |||||
private Tensors _original_tensors; | |||||
public Tensors original_tensors | |||||
{ | |||||
get => _original_tensors; | |||||
set => _original_tensors = value; | |||||
} | |||||
private Shape _inferred_value; | |||||
public Shape inferred_value => _inferred_value; | |||||
private string _name; | |||||
private TensorSpec _type_spec; | |||||
public Shape shape => _type_spec.shape; | |||||
public TF_DataType dtype => _type_spec.dtype; | |||||
public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) | |||||
{ | |||||
_type_spec = type_spec; | |||||
_inferred_value = inferred_value; | |||||
_name = name; | |||||
} | |||||
public static KerasTensor from_tensor(Tensor tensor) | |||||
{ | |||||
var type_spec = tensor.ToTensorSpec(); | |||||
var kt = new KerasTensor(type_spec, name: tensor.name); | |||||
kt.original_tensors = tensor; | |||||
return kt; | |||||
} | |||||
public override string ToString() | |||||
=> _original_tensors.Length switch | |||||
{ | |||||
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]", | |||||
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}", | |||||
_ => _original_tensors.ToString(), | |||||
}; | |||||
private string GetInferredValueString() | |||||
=> _inferred_value == null ? "" : ""; | |||||
public static implicit operator Tensors(KerasTensor kt) | |||||
=> kt._original_tensors; | |||||
public static implicit operator Tensor(KerasTensor kt) | |||||
{ | |||||
Tensor tensor = kt._original_tensors; | |||||
tensor.IsFromKerasTensor = true; | |||||
return tensor; | |||||
} | |||||
public static implicit operator KerasTensor(Tensor tensor) | |||||
=> from_tensor(tensor); | |||||
public static implicit operator KerasTensor(Tensors tensors) | |||||
=> from_tensor(tensors.First()); | |||||
} |
@@ -1,53 +0,0 @@ | |||||
namespace Tensorflow.Keras.Engine; | |||||
/// <summary> | |||||
/// A representation of a Keras in/output during Functional API construction. | |||||
/// </summary> | |||||
public class KerasTensor | |||||
{ | |||||
private Tensors _inferred_value; | |||||
public Tensors inferred_value | |||||
{ | |||||
get => _inferred_value; | |||||
set => _inferred_value = value; | |||||
} | |||||
private string _name; | |||||
private TensorSpec _type_spec; | |||||
public Shape shape => _type_spec.shape; | |||||
public TF_DataType dtype => _type_spec.dtype; | |||||
public KerasTensor(TensorSpec type_spec, string name = null) | |||||
{ | |||||
_type_spec = type_spec; | |||||
_name = name; | |||||
} | |||||
public static KerasTensor from_tensor(Tensor tensor) | |||||
{ | |||||
var type_spec = tensor.ToTensorSpec(); | |||||
var kt = new KerasTensor(type_spec, name: tensor.name); | |||||
kt.inferred_value = tensor; | |||||
return kt; | |||||
} | |||||
public override string ToString() | |||||
=> _inferred_value.Length switch | |||||
{ | |||||
> 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]", | |||||
1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>", | |||||
_ => _inferred_value.ToString(), | |||||
}; | |||||
public static implicit operator Tensors(KerasTensor kt) | |||||
=> kt._inferred_value; | |||||
public static implicit operator Tensor(KerasTensor kt) | |||||
=> kt._inferred_value; | |||||
public static implicit operator KerasTensor(Tensor tensor) | |||||
=> from_tensor(tensor); | |||||
public static implicit operator KerasTensor(Tensors tensors) | |||||
=> from_tensor(tensors.First()); | |||||
} |
@@ -14,19 +14,10 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.NumPy; | |||||
using System; | |||||
using System.Diagnostics.CodeAnalysis; | |||||
using System.Text; | |||||
using Tensorflow.Framework.Models; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow; | |||||
namespace Tensorflow | |||||
public partial class Tensor | |||||
{ | { | ||||
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | |||||
public partial class Tensor | |||||
{ | |||||
public TensorSpec ToTensorSpec() | |||||
=> new TensorSpec(shape, dtype, name); | |||||
} | |||||
public TensorSpec ToTensorSpec() | |||||
=> new TensorSpec(shape, dtype, name); | |||||
} | } |
@@ -0,0 +1,27 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
namespace Tensorflow; | |||||
public partial class Tensor | |||||
{ | |||||
public bool IsFromKerasTensor { get; set; } | |||||
/// <summary> | |||||
/// Keras History: (Layer, (node_index, tensor_index)) | |||||
/// </summary> | |||||
public KerasHistory KerasHistory { get; set; } | |||||
} |
@@ -146,11 +146,6 @@ namespace Tensorflow | |||||
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | ||||
} | } | ||||
/// <summary> | |||||
/// Keras History: (Layer, (node_index, tensor_index)) | |||||
/// </summary> | |||||
public KerasHistory KerasHistory { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -1,22 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||||
namespace Tensorflow.Keras.Models | |||||
namespace Tensorflow.Keras.Models; | |||||
public class ModelsApi: IModelsApi | |||||
{ | { | ||||
public class ModelsApi: IModelsApi | |||||
{ | |||||
public Functional from_config(FunctionalConfig config) | |||||
=> Functional.from_config(config); | |||||
public Functional from_config(FunctionalConfig config) | |||||
=> Functional.from_config(config); | |||||
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||||
{ | |||||
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||||
} | |||||
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||||
{ | |||||
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||||
} | } | ||||
} | } |
@@ -1,97 +1,89 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Engine; | |||||
using System.IO; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | |||||
namespace Tensorflow.Keras.Saving.SavedModel; | |||||
public class KerasLoadModelUtils | |||||
{ | { | ||||
public class KerasLoadModelUtils | |||||
/// <summary> | |||||
/// Corresponding to keras/saving/save.py/load_model | |||||
/// </summary> | |||||
/// <param name="filepath"></param> | |||||
/// <param name="custom_objects"></param> | |||||
/// <param name="compile"></param> | |||||
/// <param name="options"></param> | |||||
/// <returns></returns> | |||||
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||||
bool compile = true, LoadOptions? options = null) | |||||
{ | { | ||||
/// <summary> | |||||
/// Corresponding to keras/saving/save.py/load_model | |||||
/// </summary> | |||||
/// <param name="filepath"></param> | |||||
/// <param name="custom_objects"></param> | |||||
/// <param name="compile"></param> | |||||
/// <param name="options"></param> | |||||
/// <returns></returns> | |||||
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||||
bool compile = true, LoadOptions? options = null) | |||||
using var savingScope = SharedObjectSavingScope.Enter(); | |||||
using var ctx = LoadContext.load_context(options); | |||||
if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||||
{ | { | ||||
using (SharedObjectSavingScope.Enter()) | |||||
{ | |||||
using (LoadContext.load_context(options)) | |||||
{ | |||||
if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||||
{ | |||||
throw new IOException($"No file or directory found at {filepath}."); | |||||
} | |||||
if (Directory.Exists(filepath)) | |||||
{ | |||||
return load(filepath, compile, options); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||||
} | |||||
} | |||||
} | |||||
throw new IOException($"No file or directory found at {filepath}."); | |||||
} | } | ||||
private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||||
if (Directory.Exists(filepath)) | |||||
{ | |||||
return load(filepath, compile, options); | |||||
} | |||||
else | |||||
{ | { | ||||
SavedMetadata metadata = new SavedMetadata(); | |||||
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||||
if (File.Exists(path_to_metadata_pb)) | |||||
{ | |||||
metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" + | |||||
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it."); | |||||
} | |||||
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||||
} | |||||
} | |||||
if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||||
{ | |||||
return Loader.load(path, options: options) as Model; | |||||
} | |||||
private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||||
{ | |||||
SavedMetadata metadata; | |||||
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||||
if (File.Exists(path_to_metadata_pb)) | |||||
{ | |||||
using var stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read); | |||||
metadata = SavedMetadata.Parser.ParseFrom(stream); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" + | |||||
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it."); | |||||
} | |||||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
keras_loader.load_layers(compile: compile); | |||||
if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||||
{ | |||||
return Loader.load(path, options: options) as Model; | |||||
} | |||||
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||||
nodes_to_load["root"] = (null, null); | |||||
foreach(var item in keras_loader.LoadedNodes) | |||||
{ | |||||
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||||
} | |||||
var loaded = Loader.load_partial(path, nodes_to_load, options); | |||||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
keras_loader.load_layers(compile: compile); | |||||
keras_loader.finalize_objects(); | |||||
keras_loader.del_tracking(); | |||||
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||||
nodes_to_load["root"] = (null, null); | |||||
foreach(var item in keras_loader.LoadedNodes) | |||||
{ | |||||
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||||
} | |||||
var loaded = Loader.load_partial(path, nodes_to_load, options); | |||||
var model = loaded["root"]; | |||||
keras_loader.finalize_objects(); | |||||
keras_loader.del_tracking(); | |||||
if(model is Model && compile) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
var model = loaded["root"]; | |||||
if (!tf.Context.executing_eagerly()) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
if (model is Model && compile) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
return model; | |||||
if (!tf.Context.executing_eagerly()) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | } | ||||
return model; | |||||
} | } | ||||
} | } |