Browse Source

model.get_config for Keras.

tags/v0.30
Oceania2018 4 years ago
parent
commit
dff4f510af
32 changed files with 555 additions and 116 deletions
  1. +1
    -1
      src/TensorFlowNET.Console/TensorFlowNET.Console.csproj
  2. +5
    -1
      src/TensorFlowNET.Core/Keras/Engine/INode.cs
  3. +2
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  4. +16
    -0
      src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
  5. +15
    -0
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  6. +13
    -0
      src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
  7. +25
    -0
      src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
  8. +18
    -0
      src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
  9. +6
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  10. +1
    -0
      src/TensorFlowNET.Core/Protobuf/Gen.bat
  11. +210
    -0
      src/TensorFlowNET.Core/Protobuf/SavedModel.cs
  12. +76
    -0
      src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
  13. +4
    -4
      src/TensorFlowNET.Keras/Engine/Functional.cs
  14. +4
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs
  15. +1
    -1
      src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
  16. +15
    -2
      src/TensorFlowNET.Keras/Engine/Model.Save.cs
  17. +18
    -0
      src/TensorFlowNET.Keras/Engine/Node.Serialize.cs
  18. +1
    -1
      src/TensorFlowNET.Keras/Engine/Node.cs
  19. +1
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  20. +6
    -78
      src/TensorFlowNET.Keras/KerasApi.cs
  21. +0
    -12
      src/TensorFlowNET.Keras/KerasExtension.cs
  22. +84
    -0
      src/TensorFlowNET.Keras/KerasInterface.cs
  23. +5
    -1
      src/TensorFlowNET.Keras/Layers/InputLayer.cs
  24. +1
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  25. +1
    -1
      src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs
  26. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  27. +4
    -0
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  28. +1
    -1
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
  29. +11
    -0
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  30. +1
    -1
      src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
  31. +1
    -1
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
  32. +7
    -8
      test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs

+ 1
- 1
src/TensorFlowNET.Console/TensorFlowNET.Console.csproj View File

@@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net5.0</TargetFramework>
<RootNamespace>Tensorflow</RootNamespace>
<AssemblyName>Tensorflow</AssemblyName>
</PropertyGroup>


+ 5
- 1
src/TensorFlowNET.Core/Keras/Engine/INode.cs View File

@@ -1,4 +1,6 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Engine
{
@@ -10,5 +12,7 @@ namespace Tensorflow.Keras.Engine
List<Tensor> KerasInputs { get; set; }
INode[] ParentNodes { get; }
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
bool is_input { get; }
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);
}
}

+ 2
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras
@@ -14,5 +15,6 @@ namespace Tensorflow.Keras
List<IVariableV1> trainable_variables { get; }
TensorShape output_shape { get; }
int count_params();
LayerArgs get_config();
}
}

+ 16
- 0
src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Saving
{
public class LayerConfig
{
public string Name { get; set; }
public string ClassName { get; set; }
public LayerArgs Config { get; set; }
public List<INode> InboundNodes { get; set; }
}
}

+ 15
- 0
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Saving
{
public class ModelConfig
{
public string Name { get; set; }
public List<LayerConfig> Layers { get; set; }
public List<ILayer> InputLayers { get; set; }
public List<ILayer> OutputLayers { get; set; }
}
}

+ 13
- 0
src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving
{
public class NodeConfig
{
public string Name { get; set; }
public int NodeIndex { get; set; }
public int TensorIndex { get; set; }
}
}

+ 25
- 0
src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;

namespace Tensorflow.ModelSaving
{
public class ModelSaver
{
public void save(Trackable obj, string export_dir, SaveOptions options = null)
{
var saved_model = new SavedModel();
var meta_graph_def = new MetaGraphDef();
saved_model.MetaGraphs.Add(meta_graph_def);
_build_meta_graph(obj, export_dir, options, meta_graph_def);
}

void _build_meta_graph(Trackable obj, string export_dir, SaveOptions options,
MetaGraphDef meta_graph_def = null)
{

}
}
}

+ 18
- 0
src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.ModelSaving
{
/// <summary>
/// Options for saving to SavedModel.
/// </summary>
public class SaveOptions
{
bool save_debug_info;
public SaveOptions(bool save_debug_info = false)
{
this.save_debug_info = save_debug_info;
}
}
}

+ 6
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using Tensorflow.Util;
@@ -132,5 +133,10 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

public LayerArgs get_config()
{
throw new NotImplementedException();
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Protobuf/Gen.bat View File

@@ -27,6 +27,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.pro
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_model.proto
ECHO Download `any.proto` from https://github.com/protocolbuffers/protobuf/tree/master/src/google/protobuf
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/meta_graph.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.proto


+ 210
- 0
src/TensorFlowNET.Core/Protobuf/SavedModel.cs View File

@@ -0,0 +1,210 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: tensorflow/core/protobuf/saved_model.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from tensorflow/core/protobuf/saved_model.proto</summary>
public static partial class SavedModelReflection {

#region Descriptor
/// <summary>File descriptor for tensorflow/core/protobuf/saved_model.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static SavedModelReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Cip0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvc2F2ZWRfbW9kZWwucHJvdG8S",
"CnRlbnNvcmZsb3caKXRlbnNvcmZsb3cvY29yZS9wcm90b2J1Zi9tZXRhX2dy",
"YXBoLnByb3RvIl8KClNhdmVkTW9kZWwSIgoac2F2ZWRfbW9kZWxfc2NoZW1h",
"X3ZlcnNpb24YASABKAMSLQoLbWV0YV9ncmFwaHMYAiADKAsyGC50ZW5zb3Jm",
"bG93Lk1ldGFHcmFwaERlZkJ7ChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC",
"EFNhdmVkTW9kZWxQcm90b3NQAVpIZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl",
"bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2NvcmVfcHJvdG9zX2dvX3By",
"b3Rv+AEBYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Tensorflow.MetaGraphReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedModel), global::Tensorflow.SavedModel.Parser, new[]{ "SavedModelSchemaVersion", "MetaGraphs" }, null, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// SavedModel is the high level serialization format for TensorFlow Models.
/// See [todo: doc links, similar to session_bundle] for more information.
/// </summary>
public sealed partial class SavedModel : pb::IMessage<SavedModel> {
private static readonly pb::MessageParser<SavedModel> _parser = new pb::MessageParser<SavedModel>(() => new SavedModel());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<SavedModel> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.SavedModelReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedModel() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedModel(SavedModel other) : this() {
savedModelSchemaVersion_ = other.savedModelSchemaVersion_;
metaGraphs_ = other.metaGraphs_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedModel Clone() {
return new SavedModel(this);
}

/// <summary>Field number for the "saved_model_schema_version" field.</summary>
public const int SavedModelSchemaVersionFieldNumber = 1;
private long savedModelSchemaVersion_;
/// <summary>
/// The schema version of the SavedModel instance. Used for versioning when
/// making future changes to the specification/implementation. Initial value
/// at release will be 1.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public long SavedModelSchemaVersion {
get { return savedModelSchemaVersion_; }
set {
savedModelSchemaVersion_ = value;
}
}

/// <summary>Field number for the "meta_graphs" field.</summary>
public const int MetaGraphsFieldNumber = 2;
private static readonly pb::FieldCodec<global::Tensorflow.MetaGraphDef> _repeated_metaGraphs_codec
= pb::FieldCodec.ForMessage(18, global::Tensorflow.MetaGraphDef.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.MetaGraphDef> metaGraphs_ = new pbc::RepeatedField<global::Tensorflow.MetaGraphDef>();
/// <summary>
/// One or more MetaGraphs.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.MetaGraphDef> MetaGraphs {
get { return metaGraphs_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as SavedModel);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(SavedModel other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (SavedModelSchemaVersion != other.SavedModelSchemaVersion) return false;
if(!metaGraphs_.Equals(other.metaGraphs_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (SavedModelSchemaVersion != 0L) hash ^= SavedModelSchemaVersion.GetHashCode();
hash ^= metaGraphs_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (SavedModelSchemaVersion != 0L) {
output.WriteRawTag(8);
output.WriteInt64(SavedModelSchemaVersion);
}
metaGraphs_.WriteTo(output, _repeated_metaGraphs_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (SavedModelSchemaVersion != 0L) {
size += 1 + pb::CodedOutputStream.ComputeInt64Size(SavedModelSchemaVersion);
}
size += metaGraphs_.CalculateSize(_repeated_metaGraphs_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(SavedModel other) {
if (other == null) {
return;
}
if (other.SavedModelSchemaVersion != 0L) {
SavedModelSchemaVersion = other.SavedModelSchemaVersion;
}
metaGraphs_.Add(other.metaGraphs_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 8: {
SavedModelSchemaVersion = input.ReadInt64();
break;
}
case 18: {
metaGraphs_.AddEntriesFrom(input, _repeated_metaGraphs_codec);
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

+ 76
- 0
src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs View File

@@ -0,0 +1,76 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
public ModelConfig get_config()
{
return get_network_config();
}

/// <summary>
/// Builds the config, which consists of the node graph and serialized layers.
/// </summary>
ModelConfig get_network_config()
{
var config = new ModelConfig
{
Name = name
};

var node_conversion_map = new Dictionary<string, int>();
foreach (var layer in _layers)
{
var kept_nodes = _should_skip_first_node(layer) ? 1 : 0;
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
{
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key))
{
node_conversion_map[node_key] = kept_nodes;
kept_nodes += 1;
}
}
}

var layer_configs = new List<LayerConfig>();
foreach (var layer in _layers)
{
var filtered_inbound_nodes = new List<INode>();
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
{
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key) && !node.is_input)
{
var node_data = node.serialize(_make_node_key, node_conversion_map);
throw new NotImplementedException("");
}
}

var layer_config = generic_utils.serialize_keras_object(layer);
layer_config.Name = layer.Name;
layer_config.InboundNodes = filtered_inbound_nodes;
layer_configs.Add(layer_config);
}
config.Layers = layer_configs;

return config;
}

bool _should_skip_first_node(ILayer layer)
{
return layer is Functional && layer.Layers[0] is InputLayer;
}

string _make_node_key(string layer_name, int node_index)
=> $"{layer_name}_ib-{node_index}";
}
}

+ 4
- 4
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine
/// <summary>
/// A `Functional` model is a `Model` defined as a directed graph of layers.
/// </summary>
public class Functional : Model
public partial class Functional : Model
{
TensorShape _build_input_shape;
bool _compute_output_and_mask_jointly;
@@ -283,7 +283,7 @@ namespace Tensorflow.Keras.Engine

// Propagate to all previous tensors connected to this node.
nodes_in_progress.Add(node);
if (!node.IsInput)
if (!node.is_input)
{
foreach (var k_tensor in node.KerasInputs)
{
@@ -307,7 +307,7 @@ namespace Tensorflow.Keras.Engine

Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
{
if (mask != null)
if (mask == null)
{
Tensor[] masks = new Tensor[inputs.Count()];
foreach (var (i, input_t) in enumerate(inputs))
@@ -330,7 +330,7 @@ namespace Tensorflow.Keras.Engine
foreach (Node node in nodes)
{
// Input tensors already exist.
if (node.IsInput)
if (node.is_input)
continue;

var layer_inputs = node.MapArguments(tensor_dict);


+ 4
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;
@@ -241,5 +242,8 @@ namespace Tensorflow.Keras.Engine
return weights;
}
}

public virtual LayerArgs get_config()
=> throw new NotImplementedException("");
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/MetricsContainer.cs View File

@@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Metrics;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Engine
{


+ 15
- 2
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -1,13 +1,26 @@
using System.Collections.Generic;
using Tensorflow.Keras.Metrics;
using Tensorflow.ModelSaving;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public void save(string path)
ModelSaver saver = new ModelSaver();

/// <summary>
/// Saves the model to Tensorflow SavedModel or a single HDF5 file.
/// </summary>
/// <param name="filepath"></param>
/// <param name="overwrite"></param>
/// <param name="include_optimizer"></param>
public void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
string save_format = "tf",
SaveOptions options = null)
{
saver.save(this, filepath);
}
}
}

+ 18
- 0
src/TensorFlowNET.Keras/Engine/Node.Serialize.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Engine
{
public partial class Node
{
/// <summary>
/// Serializes `Node` for Functional API's `get_config`.
/// </summary>
/// <returns></returns>
public NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map)
{
throw new NotImplementedException("");
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Node.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Engine
public TensorShape[] output_shapes;
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>();
public ILayer Layer { get; set; }
public bool IsInput => args.InputTensors == null;
public bool is_input => args.InputTensors == null;
public int[] FlatInputIds { get; set; }
public int[] FlatOutputIds { get; set; }
bool _single_positional_tensor_passed => KerasInputs.Count() == 1;


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -17,7 +17,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Engine
{


+ 6
- 78
src/TensorFlowNET.Keras/KerasApi.cs View File

@@ -1,84 +1,12 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras;

namespace Tensorflow.Keras
namespace Tensorflow
{
public class KerasApi
public static class KerasApi
{
public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers();
public Regularizers regularizers { get; } = new Regularizers();
public LayersApi layers { get; } = new LayersApi();
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
public BackendImpl backend { get; } = new BackendImpl();
public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();
public static KerasInterface Keras(this tensorflow tf)
=> new KerasInterface();

public Sequential Sequential(List<ILayer> layers = null,
string name = null)
=> new Sequential(new SequentialArgs
{
Layers = layers,
Name = name
});

/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
/// <param name="input"></param>
/// <param name="output"></param>
/// <returns></returns>
public Functional Model(Tensors inputs, Tensors outputs, string name = null)
=> new Functional(inputs, outputs, name: name);

/// <summary>
/// Instantiate a Keras tensor.
/// </summary>
/// <param name="shape"></param>
/// <param name="batch_size"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="sparse">
/// A boolean specifying whether the placeholder to be created is sparse.
/// </param>
/// <param name="ragged">
/// A boolean specifying whether the placeholder to be created is ragged.
/// </param>
/// <param name="tensor">
/// Optional existing tensor to wrap into the `Input` layer.
/// If set, the layer will not create a placeholder tensor.
/// </param>
/// <returns></returns>
public Tensor Input(TensorShape shape = null,
int batch_size = -1,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null,
bool sparse = false,
bool ragged = false,
Tensor tensor = null)
{
var args = new InputLayerArgs
{
Name = name,
InputShape = shape,
BatchSize = batch_size,
DType = dtype,
Sparse = sparse,
Ragged = ragged,
InputTensor = tensor
};

var layer = new InputLayer(args);

return layer.InboundNodes[0].Outputs;
}
public static KerasInterface keras { get; } = new KerasInterface();
}
}

+ 0
- 12
src/TensorFlowNET.Keras/KerasExtension.cs View File

@@ -1,12 +0,0 @@
using Tensorflow.Keras;

namespace Tensorflow
{
public static class KerasExt
{
public static KerasApi Keras(this tensorflow tf)
=> new KerasApi();

public static KerasApi keras { get; } = new KerasApi();
}
}

+ 84
- 0
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -0,0 +1,84 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;

namespace Tensorflow.Keras
{
public class KerasInterface
{
public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers();
public Regularizers regularizers { get; } = new Regularizers();
public LayersApi layers { get; } = new LayersApi();
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
public BackendImpl backend { get; } = new BackendImpl();
public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();

public Sequential Sequential(List<ILayer> layers = null,
string name = null)
=> new Sequential(new SequentialArgs
{
Layers = layers,
Name = name
});

/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
/// <param name="input"></param>
/// <param name="output"></param>
/// <returns></returns>
public Functional Model(Tensors inputs, Tensors outputs, string name = null)
=> new Functional(inputs, outputs, name: name);

/// <summary>
/// Instantiate a Keras tensor.
/// </summary>
/// <param name="shape"></param>
/// <param name="batch_size"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="sparse">
/// A boolean specifying whether the placeholder to be created is sparse.
/// </param>
/// <param name="ragged">
/// A boolean specifying whether the placeholder to be created is ragged.
/// </param>
/// <param name="tensor">
/// Optional existing tensor to wrap into the `Input` layer.
/// If set, the layer will not create a placeholder tensor.
/// </param>
/// <returns></returns>
public Tensor Input(TensorShape shape = null,
int batch_size = -1,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null,
bool sparse = false,
bool ragged = false,
Tensor tensor = null)
{
var args = new InputLayerArgs
{
Name = name,
InputShape = shape,
BatchSize = batch_size,
DType = dtype,
Sparse = sparse,
Ragged = ragged,
InputTensor = tensor
};

var layer = new InputLayer(args);

return layer.InboundNodes[0].Outputs;
}
}
}

+ 5
- 1
src/TensorFlowNET.Keras/Layers/InputLayer.cs View File

@@ -19,7 +19,7 @@ using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{
@@ -50,6 +50,7 @@ namespace Tensorflow.Keras.Layers
{
var prefix = "input";
name = prefix + '_' + keras.backend.get_uid(prefix);
args.Name = name;
}

if (args.DType == TF_DataType.DtInvalid)
@@ -99,5 +100,8 @@ namespace Tensorflow.Keras.Layers

tf.Context.restore_mode();
}

public override LayerArgs get_config()
=> args;
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -2,7 +2,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{


+ 1
- 1
src/TensorFlowNET.Keras/Layers/ZeroPadding2D.cs View File

@@ -2,7 +2,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{


+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -1,4 +1,4 @@
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras
{


+ 4
- 0
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -35,4 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

<ItemGroup>
<Folder Include="Saving\" />
</ItemGroup>

</Project>

+ 1
- 1
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -21,7 +21,7 @@ using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Utils
{


+ 11
- 0
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -16,11 +16,22 @@

using System;
using System.Linq;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Utils
{
public class generic_utils
{
public static LayerConfig serialize_keras_object(ILayer instance)
{
var config = instance.get_config();
return new LayerConfig
{
Config = config,
ClassName = instance.GetType().Name
};
}

public static string to_snake_case(string name)
{
return string.Concat(name.Select((x, i) =>


+ 1
- 1
src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj View File

@@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net5.0</TargetFramework>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>



+ 1
- 1
test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.UnitTest.Keras
{


+ 7
- 8
test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs View File

@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Engine;
using static Tensorflow.KerasExt;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.UnitTest.Keras
{
@@ -10,21 +10,20 @@ namespace TensorFlowNET.UnitTest.Keras
[TestClass]
public class ModelSaveTest : EagerModeTestBase
{
[TestMethod]
public void SaveAndLoadTest()
[TestMethod, Ignore]
public void GetAndFromConfig()
{
var model = GetModel();
var model = GetFunctionalModel();
var config = model.get_config();
}

Model GetModel()
Functional GetFunctionalModel()
{
// Create a simple model.
var inputs = keras.Input(shape: 32);
var dense_layer = keras.layers.Dense(1);
var outputs = dense_layer.Apply(inputs);
var model = keras.Model(inputs, outputs);
model.compile("adam", "mean_squared_error");
return model;
return keras.Model(inputs, outputs);
}
}
}

Loading…
Cancel
Save