@@ -0,0 +1,64 @@ | |||
namespace Tensorflow.Keras.Engine; | |||
/// <summary> | |||
/// A representation of a Keras in/output during Functional API construction. | |||
/// </summary> | |||
public class KerasTensor | |||
{ | |||
private Tensors _original_tensors; | |||
public Tensors original_tensors | |||
{ | |||
get => _original_tensors; | |||
set => _original_tensors = value; | |||
} | |||
private Shape _inferred_value; | |||
public Shape inferred_value => _inferred_value; | |||
private string _name; | |||
private TensorSpec _type_spec; | |||
public Shape shape => _type_spec.shape; | |||
public TF_DataType dtype => _type_spec.dtype; | |||
public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null) | |||
{ | |||
_type_spec = type_spec; | |||
_inferred_value = inferred_value; | |||
_name = name; | |||
} | |||
public static KerasTensor from_tensor(Tensor tensor) | |||
{ | |||
var type_spec = tensor.ToTensorSpec(); | |||
var kt = new KerasTensor(type_spec, name: tensor.name); | |||
kt.original_tensors = tensor; | |||
return kt; | |||
} | |||
public override string ToString() | |||
=> _original_tensors.Length switch | |||
{ | |||
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]", | |||
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}", | |||
_ => _original_tensors.ToString(), | |||
}; | |||
private string GetInferredValueString() | |||
=> _inferred_value == null ? "" : ""; | |||
public static implicit operator Tensors(KerasTensor kt) | |||
=> kt._original_tensors; | |||
public static implicit operator Tensor(KerasTensor kt) | |||
{ | |||
Tensor tensor = kt._original_tensors; | |||
tensor.IsFromKerasTensor = true; | |||
return tensor; | |||
} | |||
public static implicit operator KerasTensor(Tensor tensor) | |||
=> from_tensor(tensor); | |||
public static implicit operator KerasTensor(Tensors tensors) | |||
=> from_tensor(tensors.First()); | |||
} |
@@ -1,53 +0,0 @@ | |||
namespace Tensorflow.Keras.Engine; | |||
/// <summary> | |||
/// A representation of a Keras in/output during Functional API construction. | |||
/// </summary> | |||
public class KerasTensor | |||
{ | |||
private Tensors _inferred_value; | |||
public Tensors inferred_value | |||
{ | |||
get => _inferred_value; | |||
set => _inferred_value = value; | |||
} | |||
private string _name; | |||
private TensorSpec _type_spec; | |||
public Shape shape => _type_spec.shape; | |||
public TF_DataType dtype => _type_spec.dtype; | |||
public KerasTensor(TensorSpec type_spec, string name = null) | |||
{ | |||
_type_spec = type_spec; | |||
_name = name; | |||
} | |||
public static KerasTensor from_tensor(Tensor tensor) | |||
{ | |||
var type_spec = tensor.ToTensorSpec(); | |||
var kt = new KerasTensor(type_spec, name: tensor.name); | |||
kt.inferred_value = tensor; | |||
return kt; | |||
} | |||
public override string ToString() | |||
=> _inferred_value.Length switch | |||
{ | |||
> 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]", | |||
1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>", | |||
_ => _inferred_value.ToString(), | |||
}; | |||
public static implicit operator Tensors(KerasTensor kt) | |||
=> kt._inferred_value; | |||
public static implicit operator Tensor(KerasTensor kt) | |||
=> kt._inferred_value; | |||
public static implicit operator KerasTensor(Tensor tensor) | |||
=> from_tensor(tensor); | |||
public static implicit operator KerasTensor(Tensors tensors) | |||
=> from_tensor(tensors.First()); | |||
} |
@@ -14,19 +14,10 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.NumPy; | |||
using System; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Text; | |||
using Tensorflow.Framework.Models; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow; | |||
namespace Tensorflow | |||
public partial class Tensor | |||
{ | |||
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | |||
public partial class Tensor | |||
{ | |||
public TensorSpec ToTensorSpec() | |||
=> new TensorSpec(shape, dtype, name); | |||
} | |||
public TensorSpec ToTensorSpec() | |||
=> new TensorSpec(shape, dtype, name); | |||
} |
@@ -0,0 +1,27 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
namespace Tensorflow; | |||
public partial class Tensor | |||
{ | |||
public bool IsFromKerasTensor { get; set; } | |||
/// <summary> | |||
/// Keras History: (Layer, (node_index, tensor_index)) | |||
/// </summary> | |||
public KerasHistory KerasHistory { get; set; } | |||
} |
@@ -146,11 +146,6 @@ namespace Tensorflow | |||
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); | |||
} | |||
/// <summary> | |||
/// Keras History: (Layer, (node_index, tensor_index)) | |||
/// </summary> | |||
public KerasHistory KerasHistory { get; set; } | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||
/// </summary> | |||
@@ -1,22 +1,15 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
namespace Tensorflow.Keras.Models | |||
namespace Tensorflow.Keras.Models; | |||
public class ModelsApi: IModelsApi | |||
{ | |||
public class ModelsApi: IModelsApi | |||
{ | |||
public Functional from_config(FunctionalConfig config) | |||
=> Functional.from_config(config); | |||
public Functional from_config(FunctionalConfig config) | |||
=> Functional.from_config(config); | |||
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||
{ | |||
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||
} | |||
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||
{ | |||
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||
} | |||
} |
@@ -1,97 +1,89 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using System.IO; | |||
using Tensorflow.Train; | |||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
namespace Tensorflow.Keras.Saving.SavedModel; | |||
public class KerasLoadModelUtils | |||
{ | |||
public class KerasLoadModelUtils | |||
/// <summary> | |||
/// Corresponding to keras/saving/save.py/load_model | |||
/// </summary> | |||
/// <param name="filepath"></param> | |||
/// <param name="custom_objects"></param> | |||
/// <param name="compile"></param> | |||
/// <param name="options"></param> | |||
/// <returns></returns> | |||
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||
bool compile = true, LoadOptions? options = null) | |||
{ | |||
/// <summary> | |||
/// Corresponding to keras/saving/save.py/load_model | |||
/// </summary> | |||
/// <param name="filepath"></param> | |||
/// <param name="custom_objects"></param> | |||
/// <param name="compile"></param> | |||
/// <param name="options"></param> | |||
/// <returns></returns> | |||
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||
bool compile = true, LoadOptions? options = null) | |||
using var savingScope = SharedObjectSavingScope.Enter(); | |||
using var ctx = LoadContext.load_context(options); | |||
if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||
{ | |||
using (SharedObjectSavingScope.Enter()) | |||
{ | |||
using (LoadContext.load_context(options)) | |||
{ | |||
if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||
{ | |||
throw new IOException($"No file or directory found at {filepath}."); | |||
} | |||
if (Directory.Exists(filepath)) | |||
{ | |||
return load(filepath, compile, options); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||
} | |||
} | |||
} | |||
throw new IOException($"No file or directory found at {filepath}."); | |||
} | |||
private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||
if (Directory.Exists(filepath)) | |||
{ | |||
return load(filepath, compile, options); | |||
} | |||
else | |||
{ | |||
SavedMetadata metadata = new SavedMetadata(); | |||
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||
if (File.Exists(path_to_metadata_pb)) | |||
{ | |||
metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" + | |||
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it."); | |||
} | |||
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||
} | |||
} | |||
if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||
{ | |||
return Loader.load(path, options: options) as Model; | |||
} | |||
private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||
{ | |||
SavedMetadata metadata; | |||
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||
if (File.Exists(path_to_metadata_pb)) | |||
{ | |||
using var stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read); | |||
metadata = SavedMetadata.Parser.ParseFrom(stream); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" + | |||
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it."); | |||
} | |||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
keras_loader.load_layers(compile: compile); | |||
if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||
{ | |||
return Loader.load(path, options: options) as Model; | |||
} | |||
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||
nodes_to_load["root"] = (null, null); | |||
foreach(var item in keras_loader.LoadedNodes) | |||
{ | |||
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||
} | |||
var loaded = Loader.load_partial(path, nodes_to_load, options); | |||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
keras_loader.load_layers(compile: compile); | |||
keras_loader.finalize_objects(); | |||
keras_loader.del_tracking(); | |||
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||
nodes_to_load["root"] = (null, null); | |||
foreach(var item in keras_loader.LoadedNodes) | |||
{ | |||
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||
} | |||
var loaded = Loader.load_partial(path, nodes_to_load, options); | |||
var model = loaded["root"]; | |||
keras_loader.finalize_objects(); | |||
keras_loader.del_tracking(); | |||
if(model is Model && compile) | |||
{ | |||
// TODO(Rinne): implement it. | |||
} | |||
var model = loaded["root"]; | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
// TODO(Rinne): implement it. | |||
} | |||
if (model is Model && compile) | |||
{ | |||
// TODO(Rinne): implement it. | |||
} | |||
return model; | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
// TODO(Rinne): implement it. | |||
} | |||
return model; | |||
} | |||
} |