@@ -2,7 +2,7 @@ | |||||
<PropertyGroup> | <PropertyGroup> | ||||
<OutputType>Exe</OutputType> | <OutputType>Exe</OutputType> | ||||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||||
<TargetFramework>net5.0</TargetFramework> | |||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<AssemblyName>Tensorflow</AssemblyName> | <AssemblyName>Tensorflow</AssemblyName> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -1,4 +1,6 @@ | |||||
using System.Collections.Generic; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -10,5 +12,7 @@ namespace Tensorflow.Keras.Engine | |||||
List<Tensor> KerasInputs { get; set; } | List<Tensor> KerasInputs { get; set; } | ||||
INode[] ParentNodes { get; } | INode[] ParentNodes { get; } | ||||
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); | IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); | ||||
bool is_input { get; } | |||||
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); | |||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
@@ -14,5 +15,6 @@ namespace Tensorflow.Keras | |||||
List<IVariableV1> trainable_variables { get; } | List<IVariableV1> trainable_variables { get; } | ||||
TensorShape output_shape { get; } | TensorShape output_shape { get; } | ||||
int count_params(); | int count_params(); | ||||
LayerArgs get_config(); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class LayerConfig | |||||
{ | |||||
public string Name { get; set; } | |||||
public string ClassName { get; set; } | |||||
public LayerArgs Config { get; set; } | |||||
public List<INode> InboundNodes { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Engine; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class ModelConfig | |||||
{ | |||||
public string Name { get; set; } | |||||
public List<LayerConfig> Layers { get; set; } | |||||
public List<ILayer> InputLayers { get; set; } | |||||
public List<ILayer> OutputLayers { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class NodeConfig | |||||
{ | |||||
public string Name { get; set; } | |||||
public int NodeIndex { get; set; } | |||||
public int TensorIndex { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,25 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.ModelSaving | |||||
{ | |||||
public class ModelSaver | |||||
{ | |||||
public void save(Trackable obj, string export_dir, SaveOptions options = null) | |||||
{ | |||||
var saved_model = new SavedModel(); | |||||
var meta_graph_def = new MetaGraphDef(); | |||||
saved_model.MetaGraphs.Add(meta_graph_def); | |||||
_build_meta_graph(obj, export_dir, options, meta_graph_def); | |||||
} | |||||
void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options, | |||||
MetaGraphDef meta_graph_def = null) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,18 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.ModelSaving | |||||
{ | |||||
/// <summary> | |||||
/// Options for saving to SavedModel. | |||||
/// </summary> | |||||
public class SaveOptions | |||||
{ | |||||
bool save_debug_info; | |||||
public SaveOptions(bool save_debug_info = false) | |||||
{ | |||||
this.save_debug_info = save_debug_info; | |||||
} | |||||
} | |||||
} |
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
@@ -132,5 +133,10 @@ namespace Tensorflow | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
public LayerArgs get_config() | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -27,6 +27,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.pro | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto | |||||
ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf | ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto | ||||
@@ -0,0 +1,210 @@ | |||||
// <auto-generated> | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
// source: tensorflow/core/protobuf/saved_model.proto | |||||
// </auto-generated> | |||||
#pragma warning disable 1591, 0612, 3021 | |||||
#region Designer generated code | |||||
using pb = global::Google.Protobuf; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using pbr = global::Google.Protobuf.Reflection; | |||||
using scg = global::System.Collections.Generic; | |||||
namespace Tensorflow { | |||||
/// <summary>Holder for reflection information generated from tensorflow/core/protobuf/saved_model.proto</summary> | |||||
public static partial class SavedModelReflection { | |||||
#region Descriptor | |||||
/// <summary>File descriptor for tensorflow/core/protobuf/saved_model.proto</summary> | |||||
public static pbr::FileDescriptor Descriptor { | |||||
get { return descriptor; } | |||||
} | |||||
private static pbr::FileDescriptor descriptor; | |||||
static SavedModelReflection() { | |||||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
string.Concat( | |||||
"Cip0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfbW9kZWwucHJvdG8S", | |||||
"CnRlbnNvcmZsb3caKXRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi9tZXRhX2dy", | |||||
"YXBoLnByb3RvIl8KClNhdmVkTW9kZWwSIgoac2F2ZWRfbW9kZWxfc2NoZW1h", | |||||
"X3ZlcnNpb24YASABKAMSLQoLbWV0YV9ncmFwaHMYAiADKAsyGC50ZW5zb3Jm", | |||||
"bG93Lk1ldGFHcmFwaERlZkJ7ChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", | |||||
"EFNhdmVkTW9kZWxQcm90b3NQAVpIZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||||
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2NvcmVfcHJvdG9zX2dvX3By", | |||||
"b3Rv+AEBYgZwcm90bzM=")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
new pbr::FileDescriptor[] { global::Tensorflow.MetaGraphReflection.Descriptor, }, | |||||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedModel), global::Tensorflow.SavedModel.Parser, new[]{ "SavedModelSchemaVersion", "MetaGraphs" }, null, null, null, null) | |||||
})); | |||||
} | |||||
#endregion | |||||
} | |||||
#region Messages | |||||
/// <summary> | |||||
/// SavedModel is the high level serialization format for TensorFlow Models. | |||||
/// See [todo: doc links, similar to session_bundle] for more information. | |||||
/// </summary> | |||||
public sealed partial class SavedModel : pb::IMessage<SavedModel> { | |||||
private static readonly pb::MessageParser<SavedModel> _parser = new pb::MessageParser<SavedModel>(() => new SavedModel()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pb::MessageParser<SavedModel> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.SavedModelReflection.Descriptor.MessageTypes[0]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public SavedModel() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public SavedModel(SavedModel other) : this() { | |||||
savedModelSchemaVersion_ = other.savedModelSchemaVersion_; | |||||
metaGraphs_ = other.metaGraphs_.Clone(); | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public SavedModel Clone() { | |||||
return new SavedModel(this); | |||||
} | |||||
/// <summary>Field number for the "saved_model_schema_version" field.</summary> | |||||
public const int SavedModelSchemaVersionFieldNumber = 1; | |||||
private long savedModelSchemaVersion_; | |||||
/// <summary> | |||||
/// The schema version of the SavedModel instance. Used for versioning when | |||||
/// making future changes to the specification/implementation. Initial value | |||||
/// at release will be 1. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public long SavedModelSchemaVersion { | |||||
get { return savedModelSchemaVersion_; } | |||||
set { | |||||
savedModelSchemaVersion_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "meta_graphs" field.</summary> | |||||
public const int MetaGraphsFieldNumber = 2; | |||||
private static readonly pb::FieldCodec<global::Tensorflow.MetaGraphDef> _repeated_metaGraphs_codec | |||||
= pb::FieldCodec.ForMessage(18, global::Tensorflow.MetaGraphDef.Parser); | |||||
private readonly pbc::RepeatedField<global::Tensorflow.MetaGraphDef> metaGraphs_ = new pbc::RepeatedField<global::Tensorflow.MetaGraphDef>(); | |||||
/// <summary> | |||||
/// One or more MetaGraphs. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<global::Tensorflow.MetaGraphDef> MetaGraphs { | |||||
get { return metaGraphs_; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as SavedModel); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public bool Equals(SavedModel other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (SavedModelSchemaVersion != other.SavedModelSchemaVersion) return false; | |||||
if(!metaGraphs_.Equals(other.metaGraphs_)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (SavedModelSchemaVersion != 0L) hash ^= SavedModelSchemaVersion.GetHashCode(); | |||||
hash ^= metaGraphs_.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
if (SavedModelSchemaVersion != 0L) { | |||||
output.WriteRawTag(8); | |||||
output.WriteInt64(SavedModelSchemaVersion); | |||||
} | |||||
metaGraphs_.WriteTo(output, _repeated_metaGraphs_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (SavedModelSchemaVersion != 0L) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(SavedModelSchemaVersion); | |||||
} | |||||
size += metaGraphs_.CalculateSize(_repeated_metaGraphs_codec); | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(SavedModel other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.SavedModelSchemaVersion != 0L) { | |||||
SavedModelSchemaVersion = other.SavedModelSchemaVersion; | |||||
} | |||||
metaGraphs_.Add(other.metaGraphs_); | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 8: { | |||||
SavedModelSchemaVersion = input.ReadInt64(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
metaGraphs_.AddEntriesFrom(input, _repeated_metaGraphs_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endregion | |||||
} | |||||
#endregion Designer generated code |
@@ -0,0 +1,76 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Utils; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Engine | |||||
{ | |||||
public partial class Functional | |||||
{ | |||||
public ModelConfig get_config() | |||||
{ | |||||
return get_network_config(); | |||||
} | |||||
/// <summary> | |||||
/// Builds the config, which consists of the node graph and serialized layers. | |||||
/// </summary> | |||||
ModelConfig get_network_config() | |||||
{ | |||||
var config = new ModelConfig | |||||
{ | |||||
Name = name | |||||
}; | |||||
var node_conversion_map = new Dictionary<string, int>(); | |||||
foreach (var layer in _layers) | |||||
{ | |||||
var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; | |||||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||||
{ | |||||
var node_key = _make_node_key(layer.Name, original_node_index); | |||||
if (NetworkNodes.Contains(node_key)) | |||||
{ | |||||
node_conversion_map[node_key] = kept_nodes; | |||||
kept_nodes += 1; | |||||
} | |||||
} | |||||
} | |||||
var layer_configs = new List<LayerConfig>(); | |||||
foreach (var layer in _layers) | |||||
{ | |||||
var filtered_inbound_nodes = new List<INode>(); | |||||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||||
{ | |||||
var node_key = _make_node_key(layer.Name, original_node_index); | |||||
if (NetworkNodes.Contains(node_key) && !node.is_input) | |||||
{ | |||||
var node_data = node.serialize(_make_node_key, node_conversion_map); | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
var layer_config = generic_utils.serialize_keras_object(layer); | |||||
layer_config.Name = layer.Name; | |||||
layer_config.InboundNodes = filtered_inbound_nodes; | |||||
layer_configs.Add(layer_config); | |||||
} | |||||
config.Layers = layer_configs; | |||||
return config; | |||||
} | |||||
bool _should_skip_first_node(ILayer layer) | |||||
{ | |||||
return layer is Functional && layer.Layers[0] is InputLayer; | |||||
} | |||||
string _make_node_key(string layer_name, int node_index) | |||||
=> $"{layer_name}_ib-{node_index}"; | |||||
} | |||||
} |
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <summary> | /// <summary> | ||||
/// A `Functional` model is a `Model` defined as a directed graph of layers. | /// A `Functional` model is a `Model` defined as a directed graph of layers. | ||||
/// </summary> | /// </summary> | ||||
public class Functional : Model | |||||
public partial class Functional : Model | |||||
{ | { | ||||
TensorShape _build_input_shape; | TensorShape _build_input_shape; | ||||
bool _compute_output_and_mask_jointly; | bool _compute_output_and_mask_jointly; | ||||
@@ -283,7 +283,7 @@ namespace Tensorflow.Keras.Engine | |||||
// Propagate to all previous tensors connected to this node. | // Propagate to all previous tensors connected to this node. | ||||
nodes_in_progress.Add(node); | nodes_in_progress.Add(node); | ||||
if (!node.IsInput) | |||||
if (!node.is_input) | |||||
{ | { | ||||
foreach (var k_tensor in node.KerasInputs) | foreach (var k_tensor in node.KerasInputs) | ||||
{ | { | ||||
@@ -307,7 +307,7 @@ namespace Tensorflow.Keras.Engine | |||||
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) | Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) | ||||
{ | { | ||||
if (mask != null) | |||||
if (mask == null) | |||||
{ | { | ||||
Tensor[] masks = new Tensor[inputs.Count()]; | Tensor[] masks = new Tensor[inputs.Count()]; | ||||
foreach (var (i, input_t) in enumerate(inputs)) | foreach (var (i, input_t) in enumerate(inputs)) | ||||
@@ -330,7 +330,7 @@ namespace Tensorflow.Keras.Engine | |||||
foreach (Node node in nodes) | foreach (Node node in nodes) | ||||
{ | { | ||||
// Input tensors already exist. | // Input tensors already exist. | ||||
if (node.IsInput) | |||||
if (node.is_input) | |||||
continue; | continue; | ||||
var layer_inputs = node.MapArguments(tensor_dict); | var layer_inputs = node.MapArguments(tensor_dict); | ||||
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Threading; | using System.Threading; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -241,5 +242,8 @@ namespace Tensorflow.Keras.Engine | |||||
return weights; | return weights; | ||||
} | } | ||||
} | } | ||||
public virtual LayerArgs get_config() | |||||
=> throw new NotImplementedException(""); | |||||
} | } | ||||
} | } |
@@ -2,7 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -1,13 +1,26 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using Tensorflow.ModelSaving; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
public partial class Model | public partial class Model | ||||
{ | { | ||||
public void save(string path) | |||||
ModelSaver saver = new ModelSaver(); | |||||
/// <summary> | |||||
/// Saves the model to Tensorflow SavedModel or a single HDF5 file. | |||||
/// </summary> | |||||
/// <param name="filepath"></param> | |||||
/// <param name="overwrite"></param> | |||||
/// <param name="include_optimizer"></param> | |||||
public void save(string filepath, | |||||
bool overwrite = true, | |||||
bool include_optimizer = true, | |||||
string save_format = "tf", | |||||
SaveOptions options = null) | |||||
{ | { | ||||
saver.save(this, filepath); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,18 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Engine | |||||
{ | |||||
public partial class Node | |||||
{ | |||||
/// <summary> | |||||
/// Serializes `Node` for Functional API's `get_config`. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Engine | |||||
public TensorShape[] output_shapes; | public TensorShape[] output_shapes; | ||||
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>(); | public List<Tensor> KerasInputs { get; set; } = new List<Tensor>(); | ||||
public ILayer Layer { get; set; } | public ILayer Layer { get; set; } | ||||
public bool IsInput => args.InputTensors == null; | |||||
public bool is_input => args.InputTensors == null; | |||||
public int[] FlatInputIds { get; set; } | public int[] FlatInputIds { get; set; } | ||||
public int[] FlatOutputIds { get; set; } | public int[] FlatOutputIds { get; set; } | ||||
bool _single_positional_tensor_passed => KerasInputs.Count() == 1; | bool _single_positional_tensor_passed => KerasInputs.Count() == 1; | ||||
@@ -17,7 +17,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -1,84 +1,12 @@ | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Datasets; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Metrics; | |||||
using Tensorflow.Keras.Optimizers; | |||||
using Tensorflow.Keras; | |||||
namespace Tensorflow.Keras | |||||
namespace Tensorflow | |||||
{ | { | ||||
public class KerasApi | |||||
public static class KerasApi | |||||
{ | { | ||||
public KerasDataset datasets { get; } = new KerasDataset(); | |||||
public Initializers initializers { get; } = new Initializers(); | |||||
public Regularizers regularizers { get; } = new Regularizers(); | |||||
public LayersApi layers { get; } = new LayersApi(); | |||||
public LossesApi losses { get; } = new LossesApi(); | |||||
public Activations activations { get; } = new Activations(); | |||||
public Preprocessing preprocessing { get; } = new Preprocessing(); | |||||
public BackendImpl backend { get; } = new BackendImpl(); | |||||
public OptimizerApi optimizers { get; } = new OptimizerApi(); | |||||
public MetricsApi metrics { get; } = new MetricsApi(); | |||||
public static KerasInterface Keras(this tensorflow tf) | |||||
=> new KerasInterface(); | |||||
public Sequential Sequential(List<ILayer> layers = null, | |||||
string name = null) | |||||
=> new Sequential(new SequentialArgs | |||||
{ | |||||
Layers = layers, | |||||
Name = name | |||||
}); | |||||
/// <summary> | |||||
/// `Model` groups layers into an object with training and inference features. | |||||
/// </summary> | |||||
/// <param name="input"></param> | |||||
/// <param name="output"></param> | |||||
/// <returns></returns> | |||||
public Functional Model(Tensors inputs, Tensors outputs, string name = null) | |||||
=> new Functional(inputs, outputs, name: name); | |||||
/// <summary> | |||||
/// Instantiate a Keras tensor. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="batch_size"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="sparse"> | |||||
/// A boolean specifying whether the placeholder to be created is sparse. | |||||
/// </param> | |||||
/// <param name="ragged"> | |||||
/// A boolean specifying whether the placeholder to be created is ragged. | |||||
/// </param> | |||||
/// <param name="tensor"> | |||||
/// Optional existing tensor to wrap into the `Input` layer. | |||||
/// If set, the layer will not create a placeholder tensor. | |||||
/// </param> | |||||
/// <returns></returns> | |||||
public Tensor Input(TensorShape shape = null, | |||||
int batch_size = -1, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
string name = null, | |||||
bool sparse = false, | |||||
bool ragged = false, | |||||
Tensor tensor = null) | |||||
{ | |||||
var args = new InputLayerArgs | |||||
{ | |||||
Name = name, | |||||
InputShape = shape, | |||||
BatchSize = batch_size, | |||||
DType = dtype, | |||||
Sparse = sparse, | |||||
Ragged = ragged, | |||||
InputTensor = tensor | |||||
}; | |||||
var layer = new InputLayer(args); | |||||
return layer.InboundNodes[0].Outputs; | |||||
} | |||||
public static KerasInterface keras { get; } = new KerasInterface(); | |||||
} | } | ||||
} | } |
@@ -1,12 +0,0 @@ | |||||
using Tensorflow.Keras; | |||||
namespace Tensorflow | |||||
{ | |||||
public static class KerasExt | |||||
{ | |||||
public static KerasApi Keras(this tensorflow tf) | |||||
=> new KerasApi(); | |||||
public static KerasApi keras { get; } = new KerasApi(); | |||||
} | |||||
} |
@@ -0,0 +1,84 @@ | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Datasets; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Metrics; | |||||
using Tensorflow.Keras.Optimizers; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public class KerasInterface | |||||
{ | |||||
public KerasDataset datasets { get; } = new KerasDataset(); | |||||
public Initializers initializers { get; } = new Initializers(); | |||||
public Regularizers regularizers { get; } = new Regularizers(); | |||||
public LayersApi layers { get; } = new LayersApi(); | |||||
public LossesApi losses { get; } = new LossesApi(); | |||||
public Activations activations { get; } = new Activations(); | |||||
public Preprocessing preprocessing { get; } = new Preprocessing(); | |||||
public BackendImpl backend { get; } = new BackendImpl(); | |||||
public OptimizerApi optimizers { get; } = new OptimizerApi(); | |||||
public MetricsApi metrics { get; } = new MetricsApi(); | |||||
public Sequential Sequential(List<ILayer> layers = null, | |||||
string name = null) | |||||
=> new Sequential(new SequentialArgs | |||||
{ | |||||
Layers = layers, | |||||
Name = name | |||||
}); | |||||
/// <summary> | |||||
/// `Model` groups layers into an object with training and inference features. | |||||
/// </summary> | |||||
/// <param name="input"></param> | |||||
/// <param name="output"></param> | |||||
/// <returns></returns> | |||||
public Functional Model(Tensors inputs, Tensors outputs, string name = null) | |||||
=> new Functional(inputs, outputs, name: name); | |||||
/// <summary> | |||||
/// Instantiate a Keras tensor. | |||||
/// </summary> | |||||
/// <param name="shape"></param> | |||||
/// <param name="batch_size"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="sparse"> | |||||
/// A boolean specifying whether the placeholder to be created is sparse. | |||||
/// </param> | |||||
/// <param name="ragged"> | |||||
/// A boolean specifying whether the placeholder to be created is ragged. | |||||
/// </param> | |||||
/// <param name="tensor"> | |||||
/// Optional existing tensor to wrap into the `Input` layer. | |||||
/// If set, the layer will not create a placeholder tensor. | |||||
/// </param> | |||||
/// <returns></returns> | |||||
public Tensor Input(TensorShape shape = null, | |||||
int batch_size = -1, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
string name = null, | |||||
bool sparse = false, | |||||
bool ragged = false, | |||||
Tensor tensor = null) | |||||
{ | |||||
var args = new InputLayerArgs | |||||
{ | |||||
Name = name, | |||||
InputShape = shape, | |||||
BatchSize = batch_size, | |||||
DType = dtype, | |||||
Sparse = sparse, | |||||
Ragged = ragged, | |||||
InputTensor = tensor | |||||
}; | |||||
var layer = new InputLayer(args); | |||||
return layer.InboundNodes[0].Outputs; | |||||
} | |||||
} | |||||
} |
@@ -19,7 +19,7 @@ using Tensorflow.Framework.Models; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
@@ -50,6 +50,7 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
var prefix = "input"; | var prefix = "input"; | ||||
name = prefix + '_' + keras.backend.get_uid(prefix); | name = prefix + '_' + keras.backend.get_uid(prefix); | ||||
args.Name = name; | |||||
} | } | ||||
if (args.DType == TF_DataType.DtInvalid) | if (args.DType == TF_DataType.DtInvalid) | ||||
@@ -99,5 +100,8 @@ namespace Tensorflow.Keras.Layers | |||||
tf.Context.restore_mode(); | tf.Context.restore_mode(); | ||||
} | } | ||||
public override LayerArgs get_config() | |||||
=> args; | |||||
} | } | ||||
} | } |
@@ -2,7 +2,7 @@ | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
@@ -2,7 +2,7 @@ | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
@@ -1,4 +1,4 @@ | |||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
{ | { | ||||
@@ -35,4 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | <ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<Folder Include="Saving\" /> | |||||
</ItemGroup> | |||||
</Project> | </Project> |
@@ -21,7 +21,7 @@ using System.Linq; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
{ | { | ||||
@@ -16,11 +16,22 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
{ | { | ||||
public class generic_utils | public class generic_utils | ||||
{ | { | ||||
public static LayerConfig serialize_keras_object(ILayer instance) | |||||
{ | |||||
var config = instance.get_config(); | |||||
return new LayerConfig | |||||
{ | |||||
Config = config, | |||||
ClassName = instance.GetType().Name | |||||
}; | |||||
} | |||||
public static string to_snake_case(string name) | public static string to_snake_case(string name) | ||||
{ | { | ||||
return string.Concat(name.Select((x, i) => | return string.Concat(name.Select((x, i) => | ||||
@@ -2,7 +2,7 @@ | |||||
<PropertyGroup> | <PropertyGroup> | ||||
<OutputType>Exe</OutputType> | <OutputType>Exe</OutputType> | ||||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||||
<TargetFramework>net5.0</TargetFramework> | |||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -1,6 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | using NumSharp; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace TensorFlowNET.UnitTest.Keras | namespace TensorFlowNET.UnitTest.Keras | ||||
{ | { | ||||
@@ -1,6 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.KerasExt; | |||||
using static Tensorflow.KerasApi; | |||||
namespace TensorFlowNET.UnitTest.Keras | namespace TensorFlowNET.UnitTest.Keras | ||||
{ | { | ||||
@@ -10,21 +10,20 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
[TestClass] | [TestClass] | ||||
public class ModelSaveTest : EagerModeTestBase | public class ModelSaveTest : EagerModeTestBase | ||||
{ | { | ||||
[TestMethod] | |||||
public void SaveAndLoadTest() | |||||
[TestMethod, Ignore] | |||||
public void GetAndFromConfig() | |||||
{ | { | ||||
var model = GetModel(); | |||||
var model = GetFunctionalModel(); | |||||
var config = model.get_config(); | |||||
} | } | ||||
Model GetModel() | |||||
Functional GetFunctionalModel() | |||||
{ | { | ||||
// Create a simple model. | // Create a simple model. | ||||
var inputs = keras.Input(shape: 32); | var inputs = keras.Input(shape: 32); | ||||
var dense_layer = keras.layers.Dense(1); | var dense_layer = keras.layers.Dense(1); | ||||
var outputs = dense_layer.Apply(inputs); | var outputs = dense_layer.Apply(inputs); | ||||
var model = keras.Model(inputs, outputs); | |||||
model.compile("adam", "mean_squared_error"); | |||||
return model; | |||||
return keras.Model(inputs, outputs); | |||||
} | } | ||||
} | } | ||||
} | } |