@@ -2,7 +2,7 @@ | |||
<PropertyGroup> | |||
<OutputType>Exe</OutputType> | |||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||
<TargetFramework>net5.0</TargetFramework> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<AssemblyName>Tensorflow</AssemblyName> | |||
</PropertyGroup> | |||
@@ -1,4 +1,6 @@ | |||
using System.Collections.Generic; | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -10,5 +12,7 @@ namespace Tensorflow.Keras.Engine | |||
List<Tensor> KerasInputs { get; set; } | |||
INode[] ParentNodes { get; } | |||
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); | |||
bool is_input { get; } | |||
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); | |||
} | |||
} |
@@ -1,4 +1,5 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras | |||
@@ -14,5 +15,6 @@ namespace Tensorflow.Keras | |||
List<IVariableV1> trainable_variables { get; } | |||
TensorShape output_shape { get; } | |||
int count_params(); | |||
LayerArgs get_config(); | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
public class LayerConfig | |||
{ | |||
public string Name { get; set; } | |||
public string ClassName { get; set; } | |||
public LayerArgs Config { get; set; } | |||
public List<INode> InboundNodes { get; set; } | |||
} | |||
} |
@@ -0,0 +1,15 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
public class ModelConfig | |||
{ | |||
public string Name { get; set; } | |||
public List<LayerConfig> Layers { get; set; } | |||
public List<ILayer> InputLayers { get; set; } | |||
public List<ILayer> OutputLayers { get; set; } | |||
} | |||
} |
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
public class NodeConfig | |||
{ | |||
public string Name { get; set; } | |||
public int NodeIndex { get; set; } | |||
public int TensorIndex { get; set; } | |||
} | |||
} |
@@ -0,0 +1,25 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Train; | |||
namespace Tensorflow.ModelSaving | |||
{ | |||
public class ModelSaver | |||
{ | |||
public void save(Trackable obj, string export_dir, SaveOptions options = null) | |||
{ | |||
var saved_model = new SavedModel(); | |||
var meta_graph_def = new MetaGraphDef(); | |||
saved_model.MetaGraphs.Add(meta_graph_def); | |||
_build_meta_graph(obj, export_dir, options, meta_graph_def); | |||
} | |||
void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options, | |||
MetaGraphDef meta_graph_def = null) | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.ModelSaving | |||
{ | |||
/// <summary> | |||
/// Options for saving to SavedModel. | |||
/// </summary> | |||
public class SaveOptions | |||
{ | |||
bool save_debug_info; | |||
public SaveOptions(bool save_debug_info = false) | |||
{ | |||
this.save_debug_info = save_debug_info; | |||
} | |||
} | |||
} |
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Util; | |||
@@ -132,5 +133,10 @@ namespace Tensorflow | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public LayerArgs get_config() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -27,6 +27,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.pro | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto | |||
ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto | |||
@@ -0,0 +1,210 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: tensorflow/core/protobuf/saved_model.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Tensorflow { | |||
/// <summary>Holder for reflection information generated from tensorflow/core/protobuf/saved_model.proto</summary> | |||
public static partial class SavedModelReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for tensorflow/core/protobuf/saved_model.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static SavedModelReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"Cip0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfbW9kZWwucHJvdG8S", | |||
"CnRlbnNvcmZsb3caKXRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi9tZXRhX2dy", | |||
"YXBoLnByb3RvIl8KClNhdmVkTW9kZWwSIgoac2F2ZWRfbW9kZWxfc2NoZW1h", | |||
"X3ZlcnNpb24YASABKAMSLQoLbWV0YV9ncmFwaHMYAiADKAsyGC50ZW5zb3Jm", | |||
"bG93Lk1ldGFHcmFwaERlZkJ7ChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", | |||
"EFNhdmVkTW9kZWxQcm90b3NQAVpIZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2NvcmVfcHJvdG9zX2dvX3By", | |||
"b3Rv+AEBYgZwcm90bzM=")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { global::Tensorflow.MetaGraphReflection.Descriptor, }, | |||
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedModel), global::Tensorflow.SavedModel.Parser, new[]{ "SavedModelSchemaVersion", "MetaGraphs" }, null, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
/// <summary> | |||
/// SavedModel is the high level serialization format for TensorFlow Models. | |||
/// See [todo: doc links, similar to session_bundle] for more information. | |||
/// </summary> | |||
public sealed partial class SavedModel : pb::IMessage<SavedModel> { | |||
private static readonly pb::MessageParser<SavedModel> _parser = new pb::MessageParser<SavedModel>(() => new SavedModel()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<SavedModel> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.SavedModelReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public SavedModel() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public SavedModel(SavedModel other) : this() { | |||
savedModelSchemaVersion_ = other.savedModelSchemaVersion_; | |||
metaGraphs_ = other.metaGraphs_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public SavedModel Clone() { | |||
return new SavedModel(this); | |||
} | |||
/// <summary>Field number for the "saved_model_schema_version" field.</summary> | |||
public const int SavedModelSchemaVersionFieldNumber = 1; | |||
private long savedModelSchemaVersion_; | |||
/// <summary> | |||
/// The schema version of the SavedModel instance. Used for versioning when | |||
/// making future changes to the specification/implementation. Initial value | |||
/// at release will be 1. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public long SavedModelSchemaVersion { | |||
get { return savedModelSchemaVersion_; } | |||
set { | |||
savedModelSchemaVersion_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "meta_graphs" field.</summary> | |||
public const int MetaGraphsFieldNumber = 2; | |||
private static readonly pb::FieldCodec<global::Tensorflow.MetaGraphDef> _repeated_metaGraphs_codec | |||
= pb::FieldCodec.ForMessage(18, global::Tensorflow.MetaGraphDef.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.MetaGraphDef> metaGraphs_ = new pbc::RepeatedField<global::Tensorflow.MetaGraphDef>(); | |||
/// <summary> | |||
/// One or more MetaGraphs. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::RepeatedField<global::Tensorflow.MetaGraphDef> MetaGraphs { | |||
get { return metaGraphs_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as SavedModel); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(SavedModel other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (SavedModelSchemaVersion != other.SavedModelSchemaVersion) return false; | |||
if(!metaGraphs_.Equals(other.metaGraphs_)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (SavedModelSchemaVersion != 0L) hash ^= SavedModelSchemaVersion.GetHashCode(); | |||
hash ^= metaGraphs_.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
if (SavedModelSchemaVersion != 0L) { | |||
output.WriteRawTag(8); | |||
output.WriteInt64(SavedModelSchemaVersion); | |||
} | |||
metaGraphs_.WriteTo(output, _repeated_metaGraphs_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (SavedModelSchemaVersion != 0L) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt64Size(SavedModelSchemaVersion); | |||
} | |||
size += metaGraphs_.CalculateSize(_repeated_metaGraphs_codec); | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(SavedModel other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.SavedModelSchemaVersion != 0L) { | |||
SavedModelSchemaVersion = other.SavedModelSchemaVersion; | |||
} | |||
metaGraphs_.Add(other.metaGraphs_); | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 8: { | |||
SavedModelSchemaVersion = input.ReadInt64(); | |||
break; | |||
} | |||
case 18: { | |||
metaGraphs_.AddEntriesFrom(input, _repeated_metaGraphs_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
@@ -0,0 +1,76 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Functional | |||
{ | |||
public ModelConfig get_config() | |||
{ | |||
return get_network_config(); | |||
} | |||
/// <summary> | |||
/// Builds the config, which consists of the node graph and serialized layers. | |||
/// </summary> | |||
ModelConfig get_network_config() | |||
{ | |||
var config = new ModelConfig | |||
{ | |||
Name = name | |||
}; | |||
var node_conversion_map = new Dictionary<string, int>(); | |||
foreach (var layer in _layers) | |||
{ | |||
var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; | |||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||
{ | |||
var node_key = _make_node_key(layer.Name, original_node_index); | |||
if (NetworkNodes.Contains(node_key)) | |||
{ | |||
node_conversion_map[node_key] = kept_nodes; | |||
kept_nodes += 1; | |||
} | |||
} | |||
} | |||
var layer_configs = new List<LayerConfig>(); | |||
foreach (var layer in _layers) | |||
{ | |||
var filtered_inbound_nodes = new List<INode>(); | |||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||
{ | |||
var node_key = _make_node_key(layer.Name, original_node_index); | |||
if (NetworkNodes.Contains(node_key) && !node.is_input) | |||
{ | |||
var node_data = node.serialize(_make_node_key, node_conversion_map); | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
var layer_config = generic_utils.serialize_keras_object(layer); | |||
layer_config.Name = layer.Name; | |||
layer_config.InboundNodes = filtered_inbound_nodes; | |||
layer_configs.Add(layer_config); | |||
} | |||
config.Layers = layer_configs; | |||
return config; | |||
} | |||
bool _should_skip_first_node(ILayer layer) | |||
{ | |||
return layer is Functional && layer.Layers[0] is InputLayer; | |||
} | |||
string _make_node_key(string layer_name, int node_index) | |||
=> $"{layer_name}_ib-{node_index}"; | |||
} | |||
} |
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine | |||
/// <summary> | |||
/// A `Functional` model is a `Model` defined as a directed graph of layers. | |||
/// </summary> | |||
public class Functional : Model | |||
public partial class Functional : Model | |||
{ | |||
TensorShape _build_input_shape; | |||
bool _compute_output_and_mask_jointly; | |||
@@ -283,7 +283,7 @@ namespace Tensorflow.Keras.Engine | |||
// Propagate to all previous tensors connected to this node. | |||
nodes_in_progress.Add(node); | |||
if (!node.IsInput) | |||
if (!node.is_input) | |||
{ | |||
foreach (var k_tensor in node.KerasInputs) | |||
{ | |||
@@ -307,7 +307,7 @@ namespace Tensorflow.Keras.Engine | |||
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) | |||
{ | |||
if (mask != null) | |||
if (mask == null) | |||
{ | |||
Tensor[] masks = new Tensor[inputs.Count()]; | |||
foreach (var (i, input_t) in enumerate(inputs)) | |||
@@ -330,7 +330,7 @@ namespace Tensorflow.Keras.Engine | |||
foreach (Node node in nodes) | |||
{ | |||
// Input tensors already exist. | |||
if (node.IsInput) | |||
if (node.is_input) | |||
continue; | |||
var layer_inputs = node.MapArguments(tensor_dict); | |||
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Threading; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
@@ -241,5 +242,8 @@ namespace Tensorflow.Keras.Engine | |||
return weights; | |||
} | |||
} | |||
public virtual LayerArgs get_config() | |||
=> throw new NotImplementedException(""); | |||
} | |||
} |
@@ -2,7 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.Metrics; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -1,13 +1,26 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.ModelSaving; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
public void save(string path) | |||
ModelSaver saver = new ModelSaver(); | |||
/// <summary> | |||
/// Saves the model to Tensorflow SavedModel or a single HDF5 file. | |||
/// </summary> | |||
/// <param name="filepath"></param> | |||
/// <param name="overwrite"></param> | |||
/// <param name="include_optimizer"></param> | |||
public void save(string filepath, | |||
bool overwrite = true, | |||
bool include_optimizer = true, | |||
string save_format = "tf", | |||
SaveOptions options = null) | |||
{ | |||
saver.save(this, filepath); | |||
} | |||
} | |||
} |
@@ -0,0 +1,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Node | |||
{ | |||
/// <summary> | |||
/// Serializes `Node` for Functional API's `get_config`. | |||
/// </summary> | |||
/// <returns></returns> | |||
public NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Engine | |||
public TensorShape[] output_shapes; | |||
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>(); | |||
public ILayer Layer { get; set; } | |||
public bool IsInput => args.InputTensors == null; | |||
public bool is_input => args.InputTensors == null; | |||
public int[] FlatInputIds { get; set; } | |||
public int[] FlatOutputIds { get; set; } | |||
bool _single_positional_tensor_passed => KerasInputs.Count() == 1; | |||
@@ -17,7 +17,7 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Layers; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -1,84 +1,12 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Datasets; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras; | |||
namespace Tensorflow.Keras | |||
namespace Tensorflow | |||
{ | |||
public class KerasApi | |||
public static class KerasApi | |||
{ | |||
public KerasDataset datasets { get; } = new KerasDataset(); | |||
public Initializers initializers { get; } = new Initializers(); | |||
public Regularizers regularizers { get; } = new Regularizers(); | |||
public LayersApi layers { get; } = new LayersApi(); | |||
public LossesApi losses { get; } = new LossesApi(); | |||
public Activations activations { get; } = new Activations(); | |||
public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
public BackendImpl backend { get; } = new BackendImpl(); | |||
public OptimizerApi optimizers { get; } = new OptimizerApi(); | |||
public MetricsApi metrics { get; } = new MetricsApi(); | |||
public static KerasInterface Keras(this tensorflow tf) | |||
=> new KerasInterface(); | |||
public Sequential Sequential(List<ILayer> layers = null, | |||
string name = null) | |||
=> new Sequential(new SequentialArgs | |||
{ | |||
Layers = layers, | |||
Name = name | |||
}); | |||
/// <summary> | |||
/// `Model` groups layers into an object with training and inference features. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="output"></param> | |||
/// <returns></returns> | |||
public Functional Model(Tensors inputs, Tensors outputs, string name = null) | |||
=> new Functional(inputs, outputs, name: name); | |||
/// <summary> | |||
/// Instantiate a Keras tensor. | |||
/// </summary> | |||
/// <param name="shape"></param> | |||
/// <param name="batch_size"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="name"></param> | |||
/// <param name="sparse"> | |||
/// A boolean specifying whether the placeholder to be created is sparse. | |||
/// </param> | |||
/// <param name="ragged"> | |||
/// A boolean specifying whether the placeholder to be created is ragged. | |||
/// </param> | |||
/// <param name="tensor"> | |||
/// Optional existing tensor to wrap into the `Input` layer. | |||
/// If set, the layer will not create a placeholder tensor. | |||
/// </param> | |||
/// <returns></returns> | |||
public Tensor Input(TensorShape shape = null, | |||
int batch_size = -1, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null, | |||
bool sparse = false, | |||
bool ragged = false, | |||
Tensor tensor = null) | |||
{ | |||
var args = new InputLayerArgs | |||
{ | |||
Name = name, | |||
InputShape = shape, | |||
BatchSize = batch_size, | |||
DType = dtype, | |||
Sparse = sparse, | |||
Ragged = ragged, | |||
InputTensor = tensor | |||
}; | |||
var layer = new InputLayer(args); | |||
return layer.InboundNodes[0].Outputs; | |||
} | |||
public static KerasInterface keras { get; } = new KerasInterface(); | |||
} | |||
} |
@@ -1,12 +0,0 @@ | |||
using Tensorflow.Keras; | |||
namespace Tensorflow | |||
{ | |||
public static class KerasExt | |||
{ | |||
public static KerasApi Keras(this tensorflow tf) | |||
=> new KerasApi(); | |||
public static KerasApi keras { get; } = new KerasApi(); | |||
} | |||
} |
@@ -0,0 +1,84 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Datasets; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Optimizers; | |||
namespace Tensorflow.Keras | |||
{ | |||
public class KerasInterface | |||
{ | |||
public KerasDataset datasets { get; } = new KerasDataset(); | |||
public Initializers initializers { get; } = new Initializers(); | |||
public Regularizers regularizers { get; } = new Regularizers(); | |||
public LayersApi layers { get; } = new LayersApi(); | |||
public LossesApi losses { get; } = new LossesApi(); | |||
public Activations activations { get; } = new Activations(); | |||
public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
public BackendImpl backend { get; } = new BackendImpl(); | |||
public OptimizerApi optimizers { get; } = new OptimizerApi(); | |||
public MetricsApi metrics { get; } = new MetricsApi(); | |||
public Sequential Sequential(List<ILayer> layers = null, | |||
string name = null) | |||
=> new Sequential(new SequentialArgs | |||
{ | |||
Layers = layers, | |||
Name = name | |||
}); | |||
/// <summary> | |||
/// `Model` groups layers into an object with training and inference features. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="output"></param> | |||
/// <returns></returns> | |||
public Functional Model(Tensors inputs, Tensors outputs, string name = null) | |||
=> new Functional(inputs, outputs, name: name); | |||
/// <summary> | |||
/// Instantiate a Keras tensor. | |||
/// </summary> | |||
/// <param name="shape"></param> | |||
/// <param name="batch_size"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="name"></param> | |||
/// <param name="sparse"> | |||
/// A boolean specifying whether the placeholder to be created is sparse. | |||
/// </param> | |||
/// <param name="ragged"> | |||
/// A boolean specifying whether the placeholder to be created is ragged. | |||
/// </param> | |||
/// <param name="tensor"> | |||
/// Optional existing tensor to wrap into the `Input` layer. | |||
/// If set, the layer will not create a placeholder tensor. | |||
/// </param> | |||
/// <returns></returns> | |||
public Tensor Input(TensorShape shape = null, | |||
int batch_size = -1, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null, | |||
bool sparse = false, | |||
bool ragged = false, | |||
Tensor tensor = null) | |||
{ | |||
var args = new InputLayerArgs | |||
{ | |||
Name = name, | |||
InputShape = shape, | |||
BatchSize = batch_size, | |||
DType = dtype, | |||
Sparse = sparse, | |||
Ragged = ragged, | |||
InputTensor = tensor | |||
}; | |||
var layer = new InputLayer(args); | |||
return layer.InboundNodes[0].Outputs; | |||
} | |||
} | |||
} |
@@ -19,7 +19,7 @@ using Tensorflow.Framework.Models; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -50,6 +50,7 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
var prefix = "input"; | |||
name = prefix + '_' + keras.backend.get_uid(prefix); | |||
args.Name = name; | |||
} | |||
if (args.DType == TF_DataType.DtInvalid) | |||
@@ -99,5 +100,8 @@ namespace Tensorflow.Keras.Layers | |||
tf.Context.restore_mode(); | |||
} | |||
public override LayerArgs get_config() | |||
=> args; | |||
} | |||
} |
@@ -2,7 +2,7 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -2,7 +2,7 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -1,4 +1,4 @@ | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras | |||
{ | |||
@@ -35,4 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<Folder Include="Saving\" /> | |||
</ItemGroup> | |||
</Project> |
@@ -21,7 +21,7 @@ using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Utils | |||
{ | |||
@@ -16,11 +16,22 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Utils | |||
{ | |||
public class generic_utils | |||
{ | |||
public static LayerConfig serialize_keras_object(ILayer instance) | |||
{ | |||
var config = instance.get_config(); | |||
return new LayerConfig | |||
{ | |||
Config = config, | |||
ClassName = instance.GetType().Name | |||
}; | |||
} | |||
public static string to_snake_case(string name) | |||
{ | |||
return string.Concat(name.Select((x, i) => | |||
@@ -2,7 +2,7 @@ | |||
<PropertyGroup> | |||
<OutputType>Exe</OutputType> | |||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||
<TargetFramework>net5.0</TargetFramework> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
@@ -1,6 +1,6 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace TensorFlowNET.UnitTest.Keras | |||
{ | |||
@@ -1,6 +1,6 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.KerasExt; | |||
using static Tensorflow.KerasApi; | |||
namespace TensorFlowNET.UnitTest.Keras | |||
{ | |||
@@ -10,21 +10,20 @@ namespace TensorFlowNET.UnitTest.Keras | |||
[TestClass] | |||
public class ModelSaveTest : EagerModeTestBase | |||
{ | |||
[TestMethod] | |||
public void SaveAndLoadTest() | |||
[TestMethod, Ignore] | |||
public void GetAndFromConfig() | |||
{ | |||
var model = GetModel(); | |||
var model = GetFunctionalModel(); | |||
var config = model.get_config(); | |||
} | |||
Model GetModel() | |||
Functional GetFunctionalModel() | |||
{ | |||
// Create a simple model. | |||
var inputs = keras.Input(shape: 32); | |||
var dense_layer = keras.layers.Dense(1); | |||
var outputs = dense_layer.Apply(inputs); | |||
var model = keras.Model(inputs, outputs); | |||
model.compile("adam", "mean_squared_error"); | |||
return model; | |||
return keras.Model(inputs, outputs); | |||
} | |||
} | |||
} |