Browse Source

Merge pull request #1033 from AsakusaRinne/support_bert_load

Change type of BuildInputShape and BatchInputShape
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
c20d854432
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 376 additions and 136 deletions
  1. +23
    -0
      src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Keras/Activations/Activations.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
  7. +3
    -3
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs
  9. +3
    -3
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs
  10. +2
    -2
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs
  11. +6
    -5
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs
  12. +75
    -0
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs
  13. +3
    -3
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs
  14. +17
    -9
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs
  15. +60
    -0
      src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
  18. +1
    -1
      src/TensorFlowNET.Core/NumPy/Axis.cs
  19. +1
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
  21. +7
    -2
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  22. +1
    -1
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  23. +1
    -5
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  24. +0
    -6
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  25. +2
    -2
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  26. +2
    -2
      src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
  27. +6
    -6
      src/TensorFlowNET.Keras/Engine/Layer.cs
  28. +30
    -10
      src/TensorFlowNET.Keras/Engine/Model.Build.cs
  29. +1
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  30. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
  31. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
  32. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
  33. +1
    -1
      src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
  34. +4
    -2
      src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
  35. +5
    -3
      src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
  36. +6
    -2
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  37. +4
    -2
      src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
  38. +3
    -2
      src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
  39. +6
    -7
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  40. +2
    -1
      src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
  41. +2
    -1
      src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
  42. +5
    -3
      src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
  43. +5
    -3
      src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
  44. +6
    -4
      src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs
  45. +2
    -2
      src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs
  46. +3
    -2
      src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
  47. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs
  48. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs
  49. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs
  50. +5
    -3
      src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
  51. +2
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  52. +5
    -3
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  53. +5
    -3
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  54. +1
    -1
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  55. +5
    -2
      src/TensorFlowNET.Keras/Saving/KerasMetaData.cs
  56. +16
    -0
      src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs
  57. +7
    -6
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  58. +5
    -0
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
  59. +2
    -2
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  60. +2
    -2
      test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs

+ 23
- 0
src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs View File

@@ -0,0 +1,23 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Extensions
{
    /// <summary>
    /// Helper extensions for reading optional properties from a <see cref="JObject"/>.
    /// </summary>
    public static class JObjectExtensions
    {
        /// <summary>
        /// Looks up <paramref name="key"/> in <paramref name="obj"/> and converts the value to
        /// <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">The type the stored token is converted to.</typeparam>
        /// <param name="obj">The JSON object to read from.</param>
        /// <param name="key">The property name to look up.</param>
        /// <returns>
        /// The converted value, or <c>default(T)</c> when the property is missing or holds an
        /// explicit JSON <c>null</c>.
        /// </returns>
        public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
        {
            var token = obj[key];

            // A property that is present but set to JSON `null` is returned as a non-null
            // JValue whose Type is Null. Converting that token to a non-nullable value type
            // throws, so it is handled the same way as a missing property.
            if (token is null || token.Type == JTokenType.Null)
            {
                return default(T);
            }

            return token.ToObject<T>();
        }
    }
}

+ 1
- 1
src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.Framework.Models
public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) :
base(shape, dtype, name)
{
}

public TensorSpec _unbatch()


+ 1
- 1
src/TensorFlowNET.Core/Keras/Activations/Activations.cs View File

@@ -1,7 +1,7 @@
using Newtonsoft.Json;
using System.Reflection;
using System.Runtime.Versioning;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;

namespace Tensorflow.Keras
{


+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs View File

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
@@ -18,7 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
}


+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs View File

@@ -1,6 +1,6 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
@@ -17,6 +17,6 @@ namespace Tensorflow.Keras.ArgsDefinition
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// <summary>
/// Only applicable to input layers.
/// </summary>
public virtual Shape BatchInputShape { get; set; }
public virtual KerasShapesWrapper BatchInputShape { get; set; }

public virtual int BatchSize { get; set; } = -1;



+ 3
- 3
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras
string Name { get; }
bool Trainable { get; }
bool Built { get; }
void build(Shape input_shape);
void build(KerasShapesWrapper input_shape);
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
@@ -22,8 +22,8 @@ namespace Tensorflow.Keras
void set_weights(IEnumerable<NDArray> weights);
List<NDArray> get_weights();
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }
KerasShapesWrapper BatchInputShape { get; }
KerasShapesWrapper BuildInputShape { get; }
TF_DataType DType { get; }
int count_params();
void adapt(Tensor data, int? batch_size = null, int? steps = null);


src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs → src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs View File

@@ -6,7 +6,7 @@ using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedActivationJsonConverter : JsonConverter
{

src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs → src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs View File

@@ -4,7 +4,7 @@ using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedAxisJsonConverter : JsonConverter
{
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
int[]? axis;
if(reader.ValueType == typeof(long))
if (reader.ValueType == typeof(long))
{
axis = new int[1];
axis[0] = (int)serializer.Deserialize(reader, typeof(int));
@@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Common
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
}
return new Axis((int[])(axis!));
return new Axis(axis!);
}
}
}

src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs → src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs View File

@@ -1,7 +1,7 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedDTypeJsonConverter : JsonConverter
{
@@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Common

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
var token = JToken.FromObject(((TF_DataType)value).as_numpy_name());
token.WriteTo(writer);
}


src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs → src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs View File

@@ -4,9 +4,10 @@ using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

using Tensorflow.Operations.Initializers;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
class InitializerInfo
{
@@ -27,7 +28,7 @@ namespace Tensorflow.Keras.Common
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var initializer = value as IInitializer;
if(initializer is null)
if (initializer is null)
{
JToken.FromObject(null).WriteTo(writer);
return;
@@ -42,7 +43,7 @@ namespace Tensorflow.Keras.Common
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var info = serializer.Deserialize<InitializerInfo>(reader);
if(info is null)
if (info is null)
{
return null;
}
@@ -54,8 +55,8 @@ namespace Tensorflow.Keras.Common
"Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()),
"RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
info.config["seed"].ToObject<int?>()),
"RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject<float>(),
maxval:info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
"RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject<float>(),
maxval: info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
"TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
info.config["seed"].ToObject<int?>()),
"VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(),

+ 75
- 0
src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs View File

@@ -0,0 +1,75 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving.Json
{
    /// <summary>
    /// JSON converter for <see cref="KerasShapesWrapper"/>.
    /// A wrapper holding exactly one shape is written as that shape alone; a wrapper holding
    /// several shapes is written as an array; an empty wrapper is written as JSON <c>null</c>.
    /// Reading accepts the same three forms.
    /// </summary>
    public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter
    {
        /// <summary>
        /// Returns true only for <see cref="KerasShapesWrapper"/> itself.
        /// </summary>
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(KerasShapesWrapper);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        /// <summary>
        /// Writes <paramref name="value"/> as JSON.
        /// </summary>
        /// <param name="writer">The writer that receives the JSON.</param>
        /// <param name="value">The wrapper to serialize; null is written as JSON <c>null</c>.</param>
        /// <param name="serializer">The calling serializer (not used).</param>
        /// <exception cref="TypeError">
        /// Thrown when <paramref name="value"/> is not a <see cref="KerasShapesWrapper"/>.
        /// </exception>
        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            // JToken.FromObject rejects a null argument, so every "nothing to write" case
            // goes through writer.WriteNull() instead.
            if (value is null)
            {
                writer.WriteNull();
                return;
            }
            if (value is not KerasShapesWrapper wrapper)
            {
                throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, but got {value.GetType()}");
            }

            var shapes = wrapper.Shapes;
            if (shapes is null || shapes.Length == 0)
            {
                writer.WriteNull();
            }
            else if (shapes.Length == 1)
            {
                // A single shape is emitted directly rather than as a one-element array.
                if (shapes[0] is null)
                {
                    writer.WriteNull();
                }
                else
                {
                    JToken.FromObject(shapes[0]).WriteTo(writer);
                }
            }
            else
            {
                JToken.FromObject(shapes).WriteTo(writer);
            }
        }

        /// <summary>
        /// Reads a <see cref="KerasShapesWrapper"/> from JSON.
        /// </summary>
        /// <returns>
        /// A wrapper built from an array token (several shapes) or an object token (one shape),
        /// or null when the token is JSON <c>null</c> or deserializes to null.
        /// </returns>
        /// <exception cref="ValueError">Thrown for any other token type.</exception>
        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            switch (reader.TokenType)
            {
                case JsonToken.StartArray:
                    {
                        TensorShapeConfig[] shapes = serializer.Deserialize<TensorShapeConfig[]>(reader);
                        if (shapes is null)
                        {
                            return null;
                        }
                        return new KerasShapesWrapper(shapes);
                    }
                case JsonToken.StartObject:
                    {
                        var shape = serializer.Deserialize<TensorShapeConfig>(reader);
                        if (shape is null)
                        {
                            return null;
                        }
                        return new KerasShapesWrapper(shape);
                    }
                case JsonToken.Null:
                    return null;
                default:
                    throw new ValueError($"Cannot deserialize the token type {reader.TokenType}");
            }
        }
    }
}

src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs → src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs View File

@@ -7,7 +7,7 @@ using System.Linq;
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedNodeConfigJsonConverter : JsonConverter
{
@@ -46,10 +46,10 @@ namespace Tensorflow.Keras.Common
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
if(values.Length == 1)
if (values.Length == 1)
{
var array = values[0] as JArray;
if(array is null)
if (array is null)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}

src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs → src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs View File

@@ -5,14 +5,14 @@ using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
class ShapeInfoFromPython
{
public string class_name { get; set; }
public long?[] items { get; set; }
}
public class CustomizedShapeJsonConverter: JsonConverter
public class CustomizedShapeJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
@@ -25,12 +25,12 @@ namespace Tensorflow.Keras.Common

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if(value is null)
if (value is null)
{
var token = JToken.FromObject(null);
token.WriteTo(writer);
}
else if(value is not Shape)
else if (value is not Shape)
{
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
}
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common
{
var shape = (value as Shape)!;
long?[] dims = new long?[shape.ndim];
for(int i = 0; i < dims.Length; i++)
for (int i = 0; i < dims.Length; i++)
{
if (shape.dims[i] == -1)
{
@@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Common
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
long?[] dims;
try
if (reader.TokenType == JsonToken.StartObject)
{
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
if (shape_info_from_python is null)
@@ -70,14 +70,22 @@ namespace Tensorflow.Keras.Common
}
dims = shape_info_from_python.items;
}
catch(JsonSerializationException)
else if (reader.TokenType == JsonToken.StartArray)
{
dims = serializer.Deserialize<long?[]>(reader);
}
else if (reader.TokenType == JsonToken.Null)
{
return null;
}
else
{
throw new ValueError($"Cannot deserialize the token {reader} as Shape.");
}
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
for (int i = 0; i < dims.Length; i++)
{
convertedDims[i] = dims[i] ?? (-1);
convertedDims[i] = dims[i] ?? -1;
}
return new Shape(convertedDims);
}

+ 60
- 0
src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs View File

@@ -0,0 +1,60 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using System.Diagnostics;
using OneOf.Types;
using Tensorflow.Keras.Saving.Json;

namespace Tensorflow.Keras.Saving
{
    /// <summary>
    /// Holds one or more shape configurations so that a single shape and a list of shapes
    /// can be passed through the same property or parameter.
    /// Serialized by <see cref="CustomizedKerasShapesWrapperJsonConverter"/>.
    /// </summary>
    [JsonConverter(typeof(CustomizedKerasShapesWrapperJsonConverter))]
    public class KerasShapesWrapper
    {
        /// <summary>
        /// The wrapped shape configurations.
        /// </summary>
        public TensorShapeConfig[] Shapes { get; set; }

        /// <summary>
        /// Wraps a single <see cref="Shape"/>.
        /// </summary>
        public KerasShapesWrapper(Shape shape)
        {
            Shapes = new TensorShapeConfig[] { shape };
        }

        /// <summary>
        /// Wraps a single <see cref="TensorShapeConfig"/>.
        /// </summary>
        public KerasShapesWrapper(TensorShapeConfig shape)
        {
            Shapes = new TensorShapeConfig[] { shape };
        }

        /// <summary>
        /// Wraps the given array of shape configurations; the array is stored as-is, not copied.
        /// </summary>
        public KerasShapesWrapper(TensorShapeConfig[] shapes)
        {
            Shapes = shapes;
        }

        /// <summary>
        /// Wraps a sequence of shapes, converting each one to a <see cref="TensorShapeConfig"/>.
        /// </summary>
        public KerasShapesWrapper(IEnumerable<Shape> shape)
        {
            Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
        }

        /// <summary>
        /// Returns the only wrapped shape as a <see cref="Shape"/>.
        /// Debug builds assert that exactly one non-null shape is wrapped.
        /// </summary>
        public Shape ToSingleShape()
        {
            Debug.Assert(Shapes.Length == 1);
            var shape_config = Shapes[0];
            Debug.Assert(shape_config is not null);
            return ToShape(shape_config);
        }

        /// <summary>
        /// Converts every wrapped shape configuration to a <see cref="Shape"/>.
        /// </summary>
        public Shape[] ToShapeArray()
        {
            return Shapes.Select(x => ToShape(x)).ToArray();
        }

        /// <summary>
        /// Builds a <see cref="Shape"/> from a configuration, mapping null items to -1.
        /// </summary>
        private static Shape ToShape(TensorShapeConfig config)
        {
            return new Shape(config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
        }

        /// <summary>
        /// Implicitly wraps a <see cref="Shape"/>. A null shape converts to a null wrapper,
        /// so "not set" checks against null keep working after the conversion.
        /// </summary>
        public static implicit operator KerasShapesWrapper(Shape shape)
        {
            return shape is null ? null! : new KerasShapesWrapper(shape);
        }

        /// <summary>
        /// Implicitly wraps a <see cref="TensorShapeConfig"/>. A null configuration converts
        /// to a null wrapper rather than a wrapper holding a null entry.
        /// </summary>
        public static implicit operator KerasShapesWrapper(TensorShapeConfig shape)
        {
            return shape is null ? null! : new KerasShapesWrapper(shape);
        }

    }
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -9,7 +9,7 @@ using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

namespace Tensorflow.Keras.Saving
{
public class ModelConfig : IKerasConfig
public class FunctionalConfig : IKerasConfig
{
[JsonProperty("name")]
public string Name { get; set; }


+ 1
- 1
src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs View File

@@ -2,7 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;

namespace Tensorflow.Keras.Saving
{


+ 1
- 1
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -19,7 +19,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;

namespace Tensorflow
{


+ 1
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -19,7 +19,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;
using Tensorflow.NumPy;

namespace Tensorflow


+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -16,7 +16,7 @@

using Newtonsoft.Json;
using System.Collections.Generic;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;

namespace Tensorflow
{


+ 7
- 2
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -80,9 +80,9 @@ namespace Tensorflow

public Shape OutputShape => throw new NotImplementedException();

public Shape BatchInputShape => throw new NotImplementedException();
public KerasShapesWrapper BatchInputShape => throw new NotImplementedException();

public TensorShapeConfig BuildInputShape => throw new NotImplementedException();
public KerasShapesWrapper BuildInputShape => throw new NotImplementedException();

public TF_DataType DType => throw new NotImplementedException();
protected bool built = false;
@@ -162,6 +162,11 @@ namespace Tensorflow
throw new NotImplementedException();
}

/// <summary>
/// Not implemented; always throws <see cref="NotImplementedException"/>.
/// </summary>
public void build(KerasShapesWrapper input_shape) => throw new NotImplementedException();

/// <summary>
/// Not implemented; always throws <see cref="NotImplementedException"/>.
/// </summary>
public Trackable GetTrackable() => throw new NotImplementedException();

public void adapt(Tensor data, int? batch_size = null, int? steps = null)


+ 1
- 1
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -1,5 +1,5 @@
using Newtonsoft.Json;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;

namespace Tensorflow
{


+ 1
- 5
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -116,12 +116,8 @@ namespace Tensorflow.Training.Saving.SavedModel
}

Dictionary<string, ConcreteFunction> loaded_gradients = new();
// Debug(Rinne)
var temp = _sort_function_defs(library, function_deps);
int i = 0;
foreach (var fdef in temp)
foreach (var fdef in _sort_function_defs(library, function_deps))
{
i++;
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);

object structured_input_signature = null;


+ 0
- 6
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -214,12 +214,6 @@ namespace Tensorflow
continue;
}
var proto = _proto.Nodes[node_id];
if(node_id == 10522)
{
// Debug(Rinne)
Console.WriteLine();
}
var temp = _get_node_dependencies(proto);
foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
{
deps.Add(dep);


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
public static Functional from_config(ModelConfig config)
public static Functional from_config(FunctionalConfig config)
{
var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config);
var model = new Functional(input_tensors, output_tensors, name: config.Name);
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="config"></param>
/// <returns></returns>
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null)
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(FunctionalConfig config, Dictionary<string, ILayer>? created_layers = null)
{
// Layer instances created during the graph reconstruction process.
created_layers = created_layers ?? new Dictionary<string, ILayer>();


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs View File

@@ -19,9 +19,9 @@ namespace Tensorflow.Keras.Engine
/// <summary>
/// Builds the config, which consists of the node graph and serialized layers.
/// </summary>
ModelConfig get_network_config()
FunctionalConfig get_network_config()
{
var config = new ModelConfig
var config = new FunctionalConfig
{
Name = name
};


+ 6
- 6
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -211,9 +211,9 @@ namespace Tensorflow.Keras.Engine

protected bool computePreviousMask;
protected List<Operation> updates;
public Shape BatchInputShape => args.BatchInputShape;
protected TensorShapeConfig _buildInputShape = null;
public TensorShapeConfig BuildInputShape => _buildInputShape;
public KerasShapesWrapper BatchInputShape => args.BatchInputShape;
protected KerasShapesWrapper _buildInputShape = null;
public KerasShapesWrapper BuildInputShape => _buildInputShape;

List<INode> inboundNodes;
public List<INode> InboundNodes => inboundNodes;
@@ -284,7 +284,7 @@ namespace Tensorflow.Keras.Engine
// Manage input shape information if passed.
if (args.BatchInputShape == null && args.InputShape != null)
{
args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();
args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray());
}
}

@@ -363,7 +363,7 @@ namespace Tensorflow.Keras.Engine
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());
}
build(inputs.shape);
build(new KerasShapesWrapper(inputs.shape));

if (need_restore_mode)
tf.Context.restore_mode();
@@ -371,7 +371,7 @@ namespace Tensorflow.Keras.Engine
built = true;
}

public virtual void build(Shape input_shape)
public virtual void build(KerasShapesWrapper input_shape)
{
_buildInputShape = input_shape;
built = true;


+ 30
- 10
src/TensorFlowNET.Keras/Engine/Model.Build.cs View File

@@ -1,6 +1,8 @@
using System;
using System.Linq;
using Tensorflow.Graphs;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -8,22 +10,40 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
if (this is Functional || this is Sequential)
if (_is_graph_network || this is Functional || this is Sequential)
{
base.build(input_shape);
return;
}

var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();

graph.as_default();

var x = tf.placeholder(DType, input_shape);
Call(x, training: false);

graph.Exit();
if(input_shape is not null && this.inputs is null)
{
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
graph.as_default();
var shapes = input_shape.ToShapeArray();
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)));
try
{
Call(x, training: false);
}
catch (InvalidArgumentError)
{
throw new ValueError("You cannot build your model by calling `build` " +
"if your layers do not support float type inputs. " +
"Instead, in order to instantiate and build your " +
"model, `call` your model on real tensor data (of the correct dtype).");
}
catch (TypeError)
{
throw new ValueError("You cannot build your model by calling `build` " +
"if your layers do not support float type inputs. " +
"Instead, in order to instantiate and build your " +
"model, `call` your model on real tensor data (of the correct dtype).");
}
graph.Exit();
}

base.build(input_shape);
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -92,7 +92,7 @@ namespace Tensorflow.Keras.Engine
{
// Instantiate an input layer.
var x = keras.Input(
batch_input_shape: layer.BatchInputShape,
batch_input_shape: layer.BatchInputShape.ToSingleShape(),
dtype: layer.DType,
name: layer.Name + "_input");



+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/ELU.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers {
@@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers {
this.args = args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
if (alpha < 0f)
{


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers {
@@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers {
{
// Exponential has no args
}
public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
base.build(input_shape);
}


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/SELU.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers {
@@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers {
public SELU ( LayerArgs args ) : base(args) {
// SELU has no arguments
}
public override void build(Shape input_shape) {
public override void build(KerasShapesWrapper input_shape) {
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Attention/Attention.cs View File

@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Layers
}

// Creates variable when `use_scale` is True or `score_mode` is `concat`.
public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
if (this.use_scale)
this.scale = this.add_weight(name: "scale",


+ 4
- 2
src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs View File

@@ -19,6 +19,7 @@ using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -58,13 +59,14 @@ namespace Tensorflow.Keras.Layers
return args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var single_shape = input_shape.ToSingleShape();
if (len(input_shape) != 4)
throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}");

var channel_axis = _get_channel_axis();
var input_dim = input_shape[-1];
var input_dim = single_shape[-1];
var kernel_shape = new Shape(kernel_size[0], kernel_size[1], filters, input_dim);

kernel = add_weight(name: "kernel",


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations;
using static Tensorflow.Binding;
@@ -57,12 +58,13 @@ namespace Tensorflow.Keras.Layers
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
int channel_axis = data_format == "channels_first" ? 1 : -1;
var single_shape = input_shape.ToSingleShape();
var input_channel = channel_axis < 0 ?
input_shape.dims[input_shape.ndim + channel_axis] :
input_shape.dims[channel_axis];
single_shape.dims[single_shape.ndim + channel_axis] :
single_shape.dims[channel_axis];
Shape kernel_shape = kernel_size.dims.concat(new long[] { input_channel / args.Groups, filters });
kernel = add_weight(name: "kernel",
shape: kernel_shape,


+ 6
- 2
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -16,9 +16,11 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -41,10 +43,12 @@ namespace Tensorflow.Keras.Layers
this.inputSpec = new InputSpec(min_ndim: 2);
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
_buildInputShape = input_shape;
var last_dim = input_shape.dims.Last();
Debug.Assert(input_shape.Shapes.Length <= 1);
var single_shape = input_shape.ToSingleShape();
var last_dim = single_shape.dims.Last();
var axes = new Dictionary<int, int>();
axes[-1] = (int)last_dim;
inputSpec = new InputSpec(min_ndim: 2, axes: axes);


+ 4
- 2
src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs View File

@@ -6,6 +6,7 @@ using System.Linq;
using System.Text.RegularExpressions;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -119,9 +120,10 @@ namespace Tensorflow.Keras.Layers
this.bias_constraint = args.BiasConstraint;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape);
var shape_data = _analyze_einsum_string(this.equation, this.bias_axes,
input_shape.ToSingleShape(), this.partial_output_shape);
var kernel_shape = shape_data.Item1;
var bias_shape = shape_data.Item2;
this.full_output_shape = shape_data.Item3;


+ 3
- 2
src/TensorFlowNET.Keras/Layers/Core/Embedding.cs View File

@@ -17,6 +17,7 @@
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -48,13 +49,13 @@ namespace Tensorflow.Keras.Layers
args.InputShape = args.InputLength;

if (args.BatchInputShape == null)
args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();
args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray());

embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer;
SupportsMasking = mask_zero;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
tf.Context.eager_mode();
embeddings = add_weight(shape: (input_dim, output_dim),


+ 6
- 7
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -40,10 +40,10 @@ namespace Tensorflow.Keras.Layers
built = true;
SupportsMasking = true;

if (BatchInputShape != null)
if (BatchInputShape is not null)
{
args.BatchSize = (int)BatchInputShape.dims[0];
args.InputShape = BatchInputShape.dims.Skip(1).ToArray();
args.BatchSize = (int)(BatchInputShape.ToSingleShape().dims[0]);
args.InputShape = BatchInputShape.ToSingleShape().dims.Skip(1).ToArray();
}

// moved to base class
@@ -63,9 +63,8 @@ namespace Tensorflow.Keras.Layers
{
if (args.InputShape != null)
{
args.BatchInputShape = new long[] { args.BatchSize }
.Concat(args.InputShape.dims)
.ToArray();
args.BatchInputShape = new Saving.KerasShapesWrapper(new long[] { args.BatchSize }
.Concat(args.InputShape.dims).ToArray());
}
else
{
@@ -76,7 +75,7 @@ namespace Tensorflow.Keras.Layers
graph.as_default();

args.InputTensor = keras.backend.placeholder(
shape: BatchInputShape,
shape: BatchInputShape.ToSingleShape(),
dtype: DType,
name: Name,
sparse: args.Sparse,


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs View File

@@ -4,6 +4,7 @@ using System.Linq;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
/*var shape_set = new HashSet<Shape>();
var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray();


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Merging/Merge.cs View File

@@ -4,6 +4,7 @@ using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers

}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
// output_shape = input_shape.dims[1^];
_buildInputShape = input_shape;


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

@@ -53,9 +54,10 @@ namespace Tensorflow.Keras.Layers
axis = args.Axis.dims.Select(x => (int)x).ToArray();
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var ndims = input_shape.ndim;
var single_shape = input_shape.ToSingleShape();
var ndims = single_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
args.Axis.dims[idx] = axis[idx] = ndims + x;
@@ -74,7 +76,7 @@ namespace Tensorflow.Keras.Layers

var axis_to_dim = new Dictionary<int, int>();
foreach (var x in axis)
axis_to_dim[x] = (int)input_shape[x];
axis_to_dim[x] = (int)single_shape[x];

inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim);
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

@@ -49,16 +50,17 @@ namespace Tensorflow.Keras.Layers
axis = args.Axis.axis;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var ndims = input_shape.ndim;
var single_shape = input_shape.ToSingleShape();
var ndims = single_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
axis[idx] = ndims + x;

var axis_to_dim = new Dictionary<int, int>();
foreach (var x in axis)
axis_to_dim[x] = (int)input_shape[x];
axis_to_dim[x] = (int)single_shape[x];

inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim);
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;


+ 6
- 4
src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -45,10 +46,11 @@ namespace Tensorflow.Keras.Layers
input_variance = args.Variance;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
base.build(input_shape);
var ndim = input_shape.ndim;
var single_shape = input_shape.ToSingleShape();
var ndim = single_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
axis[idx] = ndim + x;
@@ -57,8 +59,8 @@ namespace Tensorflow.Keras.Layers
_reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray();
var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray();
// Broadcast any reduced axes.
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray());
var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray();
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? single_shape.dims[d] : 1).ToArray());
var mean_and_var_shape = _keep_axis.Select(d => single_shape.dims[d]).ToArray();

var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
var param_shape = input_shape;


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs View File

@@ -77,8 +77,8 @@ namespace Tensorflow.Keras.Layers
{
var data_shape = data.shape;
var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray();
_args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones);
build(data_shape);
_args.BatchInputShape = BatchInputShape ?? new Saving.KerasShapesWrapper(new Shape(data_shape_nones));
build(new Saving.KerasShapesWrapper(data_shape));
built = true;
}
}


+ 3
- 2
src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -35,12 +36,12 @@ namespace Tensorflow.Keras.Layers
var shape = data.output_shapes[0];
if (shape.ndim == 1)
data = data.map(tensor => array_ops.expand_dims(tensor, -1));
build(data.variant_tensor.shape);
build(new KerasShapesWrapper(data.variant_tensor.shape));
var preprocessed_inputs = data.map(_preprocess);
_index_lookup_layer.adapt(preprocessed_inputs);
}

public override void build(Shape input_shape)
/// <summary>
/// Forwards the build request, unchanged, to the base class implementation.
/// </summary>
/// <param name="input_shape">Shape wrapper passed straight through to <c>base.build</c>.</param>
public override void build(KerasShapesWrapper input_shape)
    => base.build(input_shape);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs View File

@@ -1,5 +1,6 @@
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers.Reshaping
{
@@ -11,7 +12,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
this.args = args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
if (args.cropping.rank != 1)
{


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs View File

@@ -1,5 +1,6 @@
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers.Reshaping
{
@@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
{
this.args = args;
}
public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
built = true;
_buildInputShape = input_shape;


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs View File

@@ -1,5 +1,6 @@
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers.Reshaping
{
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
this.args = args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
built = true;
_buildInputShape = input_shape;


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs View File

@@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers {
public class Permute : Layer
@@ -14,14 +15,15 @@ namespace Tensorflow.Keras.Layers {
{
this.dims = args.dims;
}
public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var rank = input_shape.rank;
var single_shape = input_shape.ToSingleShape();
var rank = single_shape.rank;
if (dims.Length != rank - 1)
{
throw new ValueError("Dimensions must match.");
}
permute = new int[input_shape.rank];
permute = new int[single_shape.rank];
dims.CopyTo(permute, 1);
built = true;
_buildInputShape = input_shape;


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
@@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers.Rnn
//}
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
if (!cell.Built)
{


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -1,5 +1,6 @@
using System.Data;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Saving;
using Tensorflow.Operations.Activation;
using static HDF.PInvoke.H5Z;
using static Tensorflow.ApiDef.Types;
@@ -14,12 +15,13 @@ namespace Tensorflow.Keras.Layers.Rnn
this.args = args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var input_dim = input_shape[-1];
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
_buildInputShape = input_shape;

kernel = add_weight("kernel", (input_shape[-1], args.Units),
kernel = add_weight("kernel", (single_shape[-1], args.Units),
initializer: args.KernelInitializer
//regularizer = self.kernel_regularizer,
//constraint = self.kernel_constraint,


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -18,11 +19,12 @@ namespace Tensorflow.Keras.Layers.Rnn
this.args = args;
}

public override void build(Shape input_shape)
public override void build(KerasShapesWrapper input_shape)
{
var input_dim = input_shape[-1];
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];

kernel = add_weight("kernel", (input_shape[-1], args.Units),
kernel = add_weight("kernel", (single_shape[-1], args.Units),
initializer: args.KernelInitializer
);



+ 1
- 1
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Models
{
public class ModelsApi: IModelsApi
{
public Functional from_config(ModelConfig config)
/// <summary>
/// Delegates to <c>Functional.from_config</c> and returns its result.
/// </summary>
/// <param name="config">Functional model configuration handed to <c>Functional.from_config</c>.</param>
/// <returns>The <c>Functional</c> instance produced by <c>Functional.from_config</c>.</returns>
public Functional from_config(FunctionalConfig config)
{
    return Functional.from_config(config);
}

public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)


+ 5
- 2
src/TensorFlowNET.Keras/Saving/KerasMetaData.cs View File

@@ -22,16 +22,19 @@ namespace Tensorflow.Keras.Saving
public int SharedObjectId { get; set; }
[JsonProperty("must_restore_from_config")]
public bool MustRestoreFromConfig { get; set; }
[JsonProperty("config")]
public JObject Config { get; set; }
[JsonProperty("build_input_shape")]
public TensorShapeConfig BuildInputShape { get; set; }
public KerasShapesWrapper BuildInputShape { get; set; }
[JsonProperty("batch_input_shape")]
public TensorShapeConfig BatchInputShape { get; set; }
public KerasShapesWrapper BatchInputShape { get; set; }
[JsonProperty("activity_regularizer")]
public IRegularizer ActivityRegularizer { get; set; }
[JsonProperty("input_spec")]
public JToken InputSpec { get; set; }
[JsonProperty("stateful")]
public bool? Stateful { get; set; }
[JsonProperty("model_config")]
public KerasModelConfig? ModelConfig { get; set; }
}
}

+ 16
- 0
src/TensorFlowNET.Keras/Saving/KerasModelConfig.cs View File

@@ -0,0 +1,16 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving
{
/// <summary>
/// JSON-mapped holder pairing a <c>class_name</c> string with a <c>config</c> JSON object.
/// </summary>
public class KerasModelConfig
{
/// <summary>Value of the <c>class_name</c> JSON property.</summary>
[JsonProperty("class_name")]
public string ClassName { get; set; }
/// <summary>
/// Value of the <c>config</c> JSON property, kept as an untyped <c>JObject</c>.
/// NOTE(review): presumably interpreted later according to <see cref="ClassName"/> - confirm against the loader.
/// </summary>
[JsonProperty("config")]
public JObject Config { get; set; }
}
}

+ 7
- 6
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -8,6 +8,7 @@ using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Tensorflow.Extensions;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
@@ -356,7 +357,7 @@ namespace Tensorflow.Keras.Saving
var (obj, setter) = _revive_from_config(identifier, metadata, node_id);
if (obj is null)
{
(obj, setter) = _revive_custom_object(identifier, metadata);
(obj, setter) = revive_custom_object(identifier, metadata);
}
if(obj is null)
{
@@ -398,7 +399,7 @@ namespace Tensorflow.Keras.Saving
return (obj, setter);
}

private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata)
{
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER)
{
@@ -437,7 +438,7 @@ namespace Tensorflow.Keras.Saving
}
else
{
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>());
model = new Functional(new Tensors(), new Tensors(), config.TryGetOrReturnNull<string>("name"));
}

// Record this model and its layers. This will later be used to reconstruct
@@ -619,7 +620,7 @@ namespace Tensorflow.Keras.Saving
}
}

private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape)
private bool _try_build_layer(Layer obj, int node_id, KerasShapesWrapper build_input_shape)
{
if (obj.Built)
return true;
@@ -679,10 +680,10 @@ namespace Tensorflow.Keras.Saving
return inputs;
}

private Shape _infer_input_shapes(int layer_node_id)
private KerasShapesWrapper _infer_input_shapes(int layer_node_id)
{
var inputs = _infer_inputs(layer_node_id);
return nest.map_structure(x => x.shape, inputs);
return new KerasShapesWrapper(nest.map_structure(x => x.shape, inputs));
}

private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child)


+ 5
- 0
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -173,6 +173,11 @@ namespace Tensorflow.Keras.Utils
obj is not Type;
}

/// <summary>
/// Returns the tensor produced by <c>array_ops.placeholder</c> for the given shape,
/// passing <c>keras.backend.floatx()</c> as its first argument.
/// </summary>
/// <param name="shape">Shape forwarded to <c>array_ops.placeholder</c>.</param>
/// <returns>The tensor returned by <c>array_ops.placeholder</c>.</returns>
public static Tensor generate_placeholders_from_shape(Shape shape)
    => array_ops.placeholder(keras.backend.floatx(), shape);

// recursive
static bool uses_keras_history(Tensor op_input)
{


+ 2
- 2
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Utils
return args as LayerArgs;
}

public static ModelConfig deserialize_model_config(JToken json)
public static FunctionalConfig deserialize_model_config(JToken json)
{
ModelConfig config = new ModelConfig();
FunctionalConfig config = new FunctionalConfig();
config.Name = json["name"].ToObject<string>();
config.Layers = new List<LayerConfig>();
var layersToken = json["layers"];


+ 2
- 2
test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs View File

@@ -18,8 +18,8 @@ namespace TensorFlowNET.Keras.UnitTest
{
var model = GetFunctionalModel();
var config = model.get_config();
Debug.Assert(config is ModelConfig);
var new_model = new ModelsApi().from_config(config as ModelConfig);
Debug.Assert(config is FunctionalConfig);
var new_model = new ModelsApi().from_config(config as FunctionalConfig);
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
}



Loading…
Cancel
Save