Change type of BuildInputShape and BatchInputShapetags/v0.100.5-BERT-load
@@ -0,0 +1,23 @@ | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Extensions | |||
{ | |||
public static class JObjectExtensions | |||
{ | |||
public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | |||
{ | |||
var res = obj[key]; | |||
if(res is null) | |||
{ | |||
return default(T); | |||
} | |||
else | |||
{ | |||
return res.ToObject<T>(); | |||
} | |||
} | |||
} | |||
} |
@@ -7,7 +7,7 @@ namespace Tensorflow.Framework.Models | |||
public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) : | |||
base(shape, dtype, name) | |||
{ | |||
} | |||
public TensorSpec _unbatch() | |||
@@ -1,7 +1,7 @@ | |||
using Newtonsoft.Json; | |||
using System.Reflection; | |||
using System.Runtime.Versioning; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving.Common; | |||
namespace Tensorflow.Keras | |||
{ | |||
@@ -2,6 +2,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
@@ -18,7 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
[JsonProperty("dtype")] | |||
public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||
[JsonProperty("trainable")] | |||
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||
} | |||
@@ -1,6 +1,6 @@ | |||
using Newtonsoft.Json; | |||
using Newtonsoft.Json.Serialization; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
@@ -17,6 +17,6 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
[JsonProperty("dtype")] | |||
public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||
} | |||
} |
@@ -33,7 +33,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
/// <summary> | |||
/// Only applicable to input layers. | |||
/// </summary> | |||
public virtual Shape BatchInputShape { get; set; } | |||
public virtual KerasShapesWrapper BatchInputShape { get; set; } | |||
public virtual int BatchSize { get; set; } = -1; | |||
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras | |||
string Name { get; } | |||
bool Trainable { get; } | |||
bool Built { get; } | |||
void build(Shape input_shape); | |||
void build(KerasShapesWrapper input_shape); | |||
List<ILayer> Layers { get; } | |||
List<INode> InboundNodes { get; } | |||
List<INode> OutboundNodes { get; } | |||
@@ -22,8 +22,8 @@ namespace Tensorflow.Keras | |||
void set_weights(IEnumerable<NDArray> weights); | |||
List<NDArray> get_weights(); | |||
Shape OutputShape { get; } | |||
Shape BatchInputShape { get; } | |||
TensorShapeConfig BuildInputShape { get; } | |||
KerasShapesWrapper BatchInputShape { get; } | |||
KerasShapesWrapper BuildInputShape { get; } | |||
TF_DataType DType { get; } | |||
int count_params(); | |||
void adapt(Tensor data, int? batch_size = null, int? steps = null); | |||
@@ -6,7 +6,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Common | |||
namespace Tensorflow.Keras.Saving.Common | |||
{ | |||
public class CustomizedActivationJsonConverter : JsonConverter | |||
{ |
@@ -4,7 +4,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Common | |||
namespace Tensorflow.Keras.Saving.Common | |||
{ | |||
public class CustomizedAxisJsonConverter : JsonConverter | |||
{ | |||
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
int[]? axis; | |||
if(reader.ValueType == typeof(long)) | |||
if (reader.ValueType == typeof(long)) | |||
{ | |||
axis = new int[1]; | |||
axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | |||
@@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Common | |||
{ | |||
throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||
} | |||
return new Axis((int[])(axis!)); | |||
return new Axis(axis!); | |||
} | |||
} | |||
} |
@@ -1,7 +1,7 @@ | |||
using Newtonsoft.Json.Linq; | |||
using Newtonsoft.Json; | |||
namespace Tensorflow.Keras.Common | |||
namespace Tensorflow.Keras.Saving.Common | |||
{ | |||
public class CustomizedDTypeJsonConverter : JsonConverter | |||
{ | |||
@@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Common | |||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
{ | |||
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)); | |||
var token = JToken.FromObject(((TF_DataType)value).as_numpy_name()); | |||
token.WriteTo(writer); | |||
} | |||
@@ -4,9 +4,10 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Operations.Initializers; | |||
namespace Tensorflow.Keras.Common | |||
namespace Tensorflow.Keras.Saving.Common | |||
{ | |||
class InitializerInfo | |||
{ | |||
@@ -27,7 +28,7 @@ namespace Tensorflow.Keras.Common | |||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
{ | |||
var initializer = value as IInitializer; | |||
if(initializer is null) | |||
if (initializer is null) | |||
{ | |||
JToken.FromObject(null).WriteTo(writer); | |||
return; | |||
@@ -42,7 +43,7 @@ namespace Tensorflow.Keras.Common | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
var info = serializer.Deserialize<InitializerInfo>(reader); | |||
if(info is null) | |||
if (info is null) | |||
{ | |||
return null; | |||
} | |||
@@ -54,8 +55,8 @@ namespace Tensorflow.Keras.Common | |||
"Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()), | |||
"RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(), | |||
info.config["seed"].ToObject<int?>()), | |||
"RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject<float>(), | |||
maxval:info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()), | |||
"RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject<float>(), | |||
maxval: info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()), | |||
"TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(), | |||
info.config["seed"].ToObject<int?>()), | |||
"VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(), |
@@ -0,0 +1,75 @@ | |||
using Newtonsoft.Json.Linq; | |||
using Newtonsoft.Json; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Saving.Json | |||
{ | |||
public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter | |||
{ | |||
public override bool CanConvert(Type objectType) | |||
{ | |||
return objectType == typeof(KerasShapesWrapper); | |||
} | |||
public override bool CanRead => true; | |||
public override bool CanWrite => true; | |||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
{ | |||
if (value is null) | |||
{ | |||
JToken.FromObject(null).WriteTo(writer); | |||
return; | |||
} | |||
if (value is not KerasShapesWrapper wrapper) | |||
{ | |||
throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, bug got {value.GetType()}"); | |||
} | |||
if (wrapper.Shapes.Length == 0) | |||
{ | |||
JToken.FromObject(null).WriteTo(writer); | |||
} | |||
else if (wrapper.Shapes.Length == 1) | |||
{ | |||
JToken.FromObject(wrapper.Shapes[0]).WriteTo(writer); | |||
} | |||
else | |||
{ | |||
JToken.FromObject(wrapper.Shapes).WriteTo(writer); | |||
} | |||
} | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
if (reader.TokenType == JsonToken.StartArray) | |||
{ | |||
TensorShapeConfig[] shapes = serializer.Deserialize<TensorShapeConfig[]>(reader); | |||
if (shapes is null) | |||
{ | |||
return null; | |||
} | |||
return new KerasShapesWrapper(shapes); | |||
} | |||
else if (reader.TokenType == JsonToken.StartObject) | |||
{ | |||
var shape = serializer.Deserialize<TensorShapeConfig>(reader); | |||
if (shape is null) | |||
{ | |||
return null; | |||
} | |||
return new KerasShapesWrapper(shape); | |||
} | |||
else if (reader.TokenType == JsonToken.Null) | |||
{ | |||
return null; | |||
} | |||
else | |||
{ | |||
throw new ValueError($"Cannot deserialize the token type {reader.TokenType}"); | |||
} | |||
} | |||
} | |||
} |
@@ -7,7 +7,7 @@ using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Common | |||
namespace Tensorflow.Keras.Saving.Common | |||
{ | |||
public class CustomizedNodeConfigJsonConverter : JsonConverter | |||
{ | |||
@@ -46,10 +46,10 @@ namespace Tensorflow.Keras.Common | |||
{ | |||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||
} | |||
if(values.Length == 1) | |||
if (values.Length == 1) | |||
{ | |||
var array = values[0] as JArray; | |||
if(array is null) | |||
if (array is null) | |||
{ | |||
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
} |
@@ -5,14 +5,14 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Common | |||
namespace Tensorflow.Keras.Saving.Common | |||
{ | |||
class ShapeInfoFromPython | |||
{ | |||
public string class_name { get; set; } | |||
public long?[] items { get; set; } | |||
} | |||
public class CustomizedShapeJsonConverter: JsonConverter | |||
public class CustomizedShapeJsonConverter : JsonConverter | |||
{ | |||
public override bool CanConvert(Type objectType) | |||
{ | |||
@@ -25,12 +25,12 @@ namespace Tensorflow.Keras.Common | |||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
{ | |||
if(value is null) | |||
if (value is null) | |||
{ | |||
var token = JToken.FromObject(null); | |||
token.WriteTo(writer); | |||
} | |||
else if(value is not Shape) | |||
else if (value is not Shape) | |||
{ | |||
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); | |||
} | |||
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common | |||
{ | |||
var shape = (value as Shape)!; | |||
long?[] dims = new long?[shape.ndim]; | |||
for(int i = 0; i < dims.Length; i++) | |||
for (int i = 0; i < dims.Length; i++) | |||
{ | |||
if (shape.dims[i] == -1) | |||
{ | |||
@@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Common | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
long?[] dims; | |||
try | |||
if (reader.TokenType == JsonToken.StartObject) | |||
{ | |||
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader); | |||
if (shape_info_from_python is null) | |||
@@ -70,14 +70,22 @@ namespace Tensorflow.Keras.Common | |||
} | |||
dims = shape_info_from_python.items; | |||
} | |||
catch(JsonSerializationException) | |||
else if (reader.TokenType == JsonToken.StartArray) | |||
{ | |||
dims = serializer.Deserialize<long?[]>(reader); | |||
} | |||
else if (reader.TokenType == JsonToken.Null) | |||
{ | |||
return null; | |||
} | |||
else | |||
{ | |||
throw new ValueError($"Cannot deserialize the token {reader} as Shape."); | |||
} | |||
long[] convertedDims = new long[dims.Length]; | |||
for(int i = 0; i < dims.Length; i++) | |||
for (int i = 0; i < dims.Length; i++) | |||
{ | |||
convertedDims[i] = dims[i] ?? (-1); | |||
convertedDims[i] = dims[i] ?? -1; | |||
} | |||
return new Shape(convertedDims); | |||
} |
@@ -0,0 +1,60 @@ | |||
using Newtonsoft.Json.Linq; | |||
using Newtonsoft.Json; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.Diagnostics; | |||
using OneOf.Types; | |||
using Tensorflow.Keras.Saving.Json; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
[JsonConverter(typeof(CustomizedKerasShapesWrapperJsonConverter))] | |||
public class KerasShapesWrapper | |||
{ | |||
public TensorShapeConfig[] Shapes { get; set; } | |||
public KerasShapesWrapper(Shape shape) | |||
{ | |||
Shapes = new TensorShapeConfig[] { shape }; | |||
} | |||
public KerasShapesWrapper(TensorShapeConfig shape) | |||
{ | |||
Shapes = new TensorShapeConfig[] { shape }; | |||
} | |||
public KerasShapesWrapper(TensorShapeConfig[] shapes) | |||
{ | |||
Shapes = shapes; | |||
} | |||
public KerasShapesWrapper(IEnumerable<Shape> shape) | |||
{ | |||
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); | |||
} | |||
public Shape ToSingleShape() | |||
{ | |||
Debug.Assert(Shapes.Length == 1); | |||
var shape_config = Shapes[0]; | |||
Debug.Assert(shape_config is not null); | |||
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray()); | |||
} | |||
public Shape[] ToShapeArray() | |||
{ | |||
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | |||
} | |||
public static implicit operator KerasShapesWrapper(Shape shape) | |||
{ | |||
return new KerasShapesWrapper(shape); | |||
} | |||
public static implicit operator KerasShapesWrapper(TensorShapeConfig shape) | |||
{ | |||
return new KerasShapesWrapper(shape); | |||
} | |||
} | |||
} |
@@ -9,7 +9,7 @@ using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
public class ModelConfig : IKerasConfig | |||
public class FunctionalConfig : IKerasConfig | |||
{ | |||
[JsonProperty("name")] | |||
public string Name { get; set; } | |||
@@ -2,7 +2,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving.Common; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
@@ -19,7 +19,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving.Common; | |||
namespace Tensorflow | |||
{ | |||
@@ -19,7 +19,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving.Common; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow | |||
@@ -16,7 +16,7 @@ | |||
using Newtonsoft.Json; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving.Common; | |||
namespace Tensorflow | |||
{ | |||
@@ -80,9 +80,9 @@ namespace Tensorflow | |||
public Shape OutputShape => throw new NotImplementedException(); | |||
public Shape BatchInputShape => throw new NotImplementedException(); | |||
public KerasShapesWrapper BatchInputShape => throw new NotImplementedException(); | |||
public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); | |||
public KerasShapesWrapper BuildInputShape => throw new NotImplementedException(); | |||
public TF_DataType DType => throw new NotImplementedException(); | |||
protected bool built = false; | |||
@@ -162,6 +162,11 @@ namespace Tensorflow | |||
throw new NotImplementedException(); | |||
} | |||
public void build(KerasShapesWrapper input_shape) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Trackable GetTrackable() { throw new NotImplementedException(); } | |||
public void adapt(Tensor data, int? batch_size = null, int? steps = null) | |||
@@ -1,5 +1,5 @@ | |||
using Newtonsoft.Json; | |||
using Tensorflow.Keras.Common; | |||
using Tensorflow.Keras.Saving.Common; | |||
namespace Tensorflow | |||
{ | |||
@@ -116,12 +116,8 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
} | |||
Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||
// Debug(Rinne) | |||
var temp = _sort_function_defs(library, function_deps); | |||
int i = 0; | |||
foreach (var fdef in temp) | |||
foreach (var fdef in _sort_function_defs(library, function_deps)) | |||
{ | |||
i++; | |||
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
object structured_input_signature = null; | |||
@@ -214,12 +214,6 @@ namespace Tensorflow | |||
continue; | |||
} | |||
var proto = _proto.Nodes[node_id]; | |||
if(node_id == 10522) | |||
{ | |||
// Debug(Rinne) | |||
Console.WriteLine(); | |||
} | |||
var temp = _get_node_dependencies(proto); | |||
foreach (var dep in _get_node_dependencies(proto).Values.Distinct()) | |||
{ | |||
deps.Add(dep); | |||
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Functional | |||
{ | |||
public static Functional from_config(ModelConfig config) | |||
public static Functional from_config(FunctionalConfig config) | |||
{ | |||
var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); | |||
var model = new Functional(input_tensors, output_tensors, name: config.Name); | |||
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||
/// </summary> | |||
/// <param name="config"></param> | |||
/// <returns></returns> | |||
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(FunctionalConfig config, Dictionary<string, ILayer>? created_layers = null) | |||
{ | |||
// Layer instances created during the graph reconstruction process. | |||
created_layers = created_layers ?? new Dictionary<string, ILayer>(); | |||
@@ -19,9 +19,9 @@ namespace Tensorflow.Keras.Engine | |||
/// <summary> | |||
/// Builds the config, which consists of the node graph and serialized layers. | |||
/// </summary> | |||
ModelConfig get_network_config() | |||
FunctionalConfig get_network_config() | |||
{ | |||
var config = new ModelConfig | |||
var config = new FunctionalConfig | |||
{ | |||
Name = name | |||
}; | |||
@@ -211,9 +211,9 @@ namespace Tensorflow.Keras.Engine | |||
protected bool computePreviousMask; | |||
protected List<Operation> updates; | |||
public Shape BatchInputShape => args.BatchInputShape; | |||
protected TensorShapeConfig _buildInputShape = null; | |||
public TensorShapeConfig BuildInputShape => _buildInputShape; | |||
public KerasShapesWrapper BatchInputShape => args.BatchInputShape; | |||
protected KerasShapesWrapper _buildInputShape = null; | |||
public KerasShapesWrapper BuildInputShape => _buildInputShape; | |||
List<INode> inboundNodes; | |||
public List<INode> InboundNodes => inboundNodes; | |||
@@ -284,7 +284,7 @@ namespace Tensorflow.Keras.Engine | |||
// Manage input shape information if passed. | |||
if (args.BatchInputShape == null && args.InputShape != null) | |||
{ | |||
args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); | |||
args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); | |||
} | |||
} | |||
@@ -363,7 +363,7 @@ namespace Tensorflow.Keras.Engine | |||
tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | |||
} | |||
build(inputs.shape); | |||
build(new KerasShapesWrapper(inputs.shape)); | |||
if (need_restore_mode) | |||
tf.Context.restore_mode(); | |||
@@ -371,7 +371,7 @@ namespace Tensorflow.Keras.Engine | |||
built = true; | |||
} | |||
public virtual void build(Shape input_shape) | |||
public virtual void build(KerasShapesWrapper input_shape) | |||
{ | |||
_buildInputShape = input_shape; | |||
built = true; | |||
@@ -1,6 +1,8 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
@@ -8,22 +10,40 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
if (this is Functional || this is Sequential) | |||
if (_is_graph_network || this is Functional || this is Sequential) | |||
{ | |||
base.build(input_shape); | |||
return; | |||
} | |||
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | |||
graph.as_default(); | |||
var x = tf.placeholder(DType, input_shape); | |||
Call(x, training: false); | |||
graph.Exit(); | |||
if(input_shape is not null && this.inputs is null) | |||
{ | |||
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | |||
graph.as_default(); | |||
var shapes = input_shape.ToShapeArray(); | |||
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); | |||
try | |||
{ | |||
Call(x, training: false); | |||
} | |||
catch (InvalidArgumentError) | |||
{ | |||
throw new ValueError("You cannot build your model by calling `build` " + | |||
"if your layers do not support float type inputs. " + | |||
"Instead, in order to instantiate and build your " + | |||
"model, `call` your model on real tensor data (of the correct dtype)."); | |||
} | |||
catch (TypeError) | |||
{ | |||
throw new ValueError("You cannot build your model by calling `build` " + | |||
"if your layers do not support float type inputs. " + | |||
"Instead, in order to instantiate and build your " + | |||
"model, `call` your model on real tensor data (of the correct dtype)."); | |||
} | |||
graph.Exit(); | |||
} | |||
base.build(input_shape); | |||
} | |||
@@ -92,7 +92,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
// Instantiate an input layer. | |||
var x = keras.Input( | |||
batch_input_shape: layer.BatchInputShape, | |||
batch_input_shape: layer.BatchInputShape.ToSingleShape(), | |||
dtype: layer.DType, | |||
name: layer.Name + "_input"); | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers { | |||
@@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers { | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
if (alpha < 0f) | |||
{ | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers { | |||
@@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers { | |||
{ | |||
// Exponential has no args | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
base.build(input_shape); | |||
} | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers { | |||
@@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers { | |||
public SELU ( LayerArgs args ) : base(args) { | |||
// SELU has no arguments | |||
} | |||
public override void build(Shape input_shape) { | |||
public override void build(KerasShapesWrapper input_shape) { | |||
if ( alpha < 0f ) { | |||
throw new ValueError("Alpha must be a number greater than 0."); | |||
} | |||
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Layers | |||
} | |||
// Creates variable when `use_scale` is True or `score_mode` is `concat`. | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
if (this.use_scale) | |||
this.scale = this.add_weight(name: "scale", | |||
@@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.KerasApi; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -58,13 +59,14 @@ namespace Tensorflow.Keras.Layers | |||
return args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var single_shape = input_shape.ToSingleShape(); | |||
if (len(input_shape) != 4) | |||
throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); | |||
var channel_axis = _get_channel_axis(); | |||
var input_dim = input_shape[-1]; | |||
var input_dim = single_shape[-1]; | |||
var kernel_shape = new Shape(kernel_size[0], kernel_size[1], filters, input_dim); | |||
kernel = add_weight(name: "kernel", | |||
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
@@ -57,12 +58,13 @@ namespace Tensorflow.Keras.Layers | |||
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
int channel_axis = data_format == "channels_first" ? 1 : -1; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var input_channel = channel_axis < 0 ? | |||
input_shape.dims[input_shape.ndim + channel_axis] : | |||
input_shape.dims[channel_axis]; | |||
single_shape.dims[single_shape.ndim + channel_axis] : | |||
single_shape.dims[channel_axis]; | |||
Shape kernel_shape = kernel_size.dims.concat(new long[] { input_channel / args.Groups, filters }); | |||
kernel = add_weight(name: "kernel", | |||
shape: kernel_shape, | |||
@@ -16,9 +16,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
@@ -41,10 +43,12 @@ namespace Tensorflow.Keras.Layers | |||
this.inputSpec = new InputSpec(min_ndim: 2); | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
_buildInputShape = input_shape; | |||
var last_dim = input_shape.dims.Last(); | |||
Debug.Assert(input_shape.Shapes.Length <= 1); | |||
var single_shape = input_shape.ToSingleShape(); | |||
var last_dim = single_shape.dims.Last(); | |||
var axes = new Dictionary<int, int>(); | |||
axes[-1] = (int)last_dim; | |||
inputSpec = new InputSpec(min_ndim: 2, axes: axes); | |||
@@ -6,6 +6,7 @@ using System.Linq; | |||
using System.Text.RegularExpressions; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.ArgsDefinition.Core; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -119,9 +120,10 @@ namespace Tensorflow.Keras.Layers | |||
this.bias_constraint = args.BiasConstraint; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape); | |||
var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, | |||
input_shape.ToSingleShape(), this.partial_output_shape); | |||
var kernel_shape = shape_data.Item1; | |||
var bias_shape = shape_data.Item2; | |||
this.full_output_shape = shape_data.Item3; | |||
@@ -17,6 +17,7 @@ | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
@@ -48,13 +49,13 @@ namespace Tensorflow.Keras.Layers | |||
args.InputShape = args.InputLength; | |||
if (args.BatchInputShape == null) | |||
args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); | |||
args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); | |||
embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer; | |||
SupportsMasking = mask_zero; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
tf.Context.eager_mode(); | |||
embeddings = add_weight(shape: (input_dim, output_dim), | |||
@@ -40,10 +40,10 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
SupportsMasking = true; | |||
if (BatchInputShape != null) | |||
if (BatchInputShape is not null) | |||
{ | |||
args.BatchSize = (int)BatchInputShape.dims[0]; | |||
args.InputShape = BatchInputShape.dims.Skip(1).ToArray(); | |||
args.BatchSize = (int)(BatchInputShape.ToSingleShape().dims[0]); | |||
args.InputShape = BatchInputShape.ToSingleShape().dims.Skip(1).ToArray(); | |||
} | |||
// moved to base class | |||
@@ -63,9 +63,8 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
if (args.InputShape != null) | |||
{ | |||
args.BatchInputShape = new long[] { args.BatchSize } | |||
.Concat(args.InputShape.dims) | |||
.ToArray(); | |||
args.BatchInputShape = new Saving.KerasShapesWrapper(new long[] { args.BatchSize } | |||
.Concat(args.InputShape.dims).ToArray()); | |||
} | |||
else | |||
{ | |||
@@ -76,7 +75,7 @@ namespace Tensorflow.Keras.Layers | |||
graph.as_default(); | |||
args.InputTensor = keras.backend.placeholder( | |||
shape: BatchInputShape, | |||
shape: BatchInputShape.ToSingleShape(), | |||
dtype: DType, | |||
name: Name, | |||
sparse: args.Sparse, | |||
@@ -4,6 +4,7 @@ using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
@@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Layers | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
/*var shape_set = new HashSet<Shape>(); | |||
var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); | |||
@@ -4,6 +4,7 @@ using System.Text; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
// output_shape = input_shape.dims[1^]; | |||
_buildInputShape = input_shape; | |||
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
@@ -53,9 +54,10 @@ namespace Tensorflow.Keras.Layers | |||
axis = args.Axis.dims.Select(x => (int)x).ToArray(); | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var ndims = input_shape.ndim; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var ndims = single_shape.ndim; | |||
foreach (var (idx, x) in enumerate(axis)) | |||
if (x < 0) | |||
args.Axis.dims[idx] = axis[idx] = ndims + x; | |||
@@ -74,7 +76,7 @@ namespace Tensorflow.Keras.Layers | |||
var axis_to_dim = new Dictionary<int, int>(); | |||
foreach (var x in axis) | |||
axis_to_dim[x] = (int)input_shape[x]; | |||
axis_to_dim[x] = (int)single_shape[x]; | |||
inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | |||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | |||
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
@@ -49,16 +50,17 @@ namespace Tensorflow.Keras.Layers | |||
axis = args.Axis.axis; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var ndims = input_shape.ndim; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var ndims = single_shape.ndim; | |||
foreach (var (idx, x) in enumerate(axis)) | |||
if (x < 0) | |||
axis[idx] = ndims + x; | |||
var axis_to_dim = new Dictionary<int, int>(); | |||
foreach (var x in axis) | |||
axis_to_dim[x] = (int)input_shape[x]; | |||
axis_to_dim[x] = (int)single_shape[x]; | |||
inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | |||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | |||
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -45,10 +46,11 @@ namespace Tensorflow.Keras.Layers | |||
input_variance = args.Variance; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
base.build(input_shape); | |||
var ndim = input_shape.ndim; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var ndim = single_shape.ndim; | |||
foreach (var (idx, x) in enumerate(axis)) | |||
if (x < 0) | |||
axis[idx] = ndim + x; | |||
@@ -57,8 +59,8 @@ namespace Tensorflow.Keras.Layers | |||
_reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); | |||
var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); | |||
// Broadcast any reduced axes. | |||
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray()); | |||
var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray(); | |||
_broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? single_shape.dims[d] : 1).ToArray()); | |||
var mean_and_var_shape = _keep_axis.Select(d => single_shape.dims[d]).ToArray(); | |||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | |||
var param_shape = input_shape; | |||
@@ -77,8 +77,8 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
var data_shape = data.shape; | |||
var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); | |||
_args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones); | |||
build(data_shape); | |||
_args.BatchInputShape = BatchInputShape ?? new Saving.KerasShapesWrapper(new Shape(data_shape_nones)); | |||
build(new Saving.KerasShapesWrapper(data_shape)); | |||
built = true; | |||
} | |||
} | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
@@ -35,12 +36,12 @@ namespace Tensorflow.Keras.Layers | |||
var shape = data.output_shapes[0]; | |||
if (shape.ndim == 1) | |||
data = data.map(tensor => array_ops.expand_dims(tensor, -1)); | |||
build(data.variant_tensor.shape); | |||
build(new KerasShapesWrapper(data.variant_tensor.shape)); | |||
var preprocessed_inputs = data.map(_preprocess); | |||
_index_lookup_layer.adapt(preprocessed_inputs); | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
base.build(input_shape); | |||
} | |||
@@ -1,5 +1,6 @@ | |||
using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers.Reshaping | |||
{ | |||
@@ -11,7 +12,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
if (args.cropping.rank != 1) | |||
{ | |||
@@ -1,5 +1,6 @@ | |||
using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers.Reshaping | |||
{ | |||
@@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||
{ | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
built = true; | |||
_buildInputShape = input_shape; | |||
@@ -1,5 +1,6 @@ | |||
using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers.Reshaping | |||
{ | |||
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
built = true; | |||
_buildInputShape = input_shape; | |||
@@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers { | |||
public class Permute : Layer | |||
@@ -14,14 +15,15 @@ namespace Tensorflow.Keras.Layers { | |||
{ | |||
this.dims = args.dims; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var rank = input_shape.rank; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var rank = single_shape.rank; | |||
if (dims.Length != rank - 1) | |||
{ | |||
throw new ValueError("Dimensions must match."); | |||
} | |||
permute = new int[input_shape.rank]; | |||
permute = new int[single_shape.rank]; | |||
dims.CopyTo(permute, 1); | |||
built = true; | |||
_buildInputShape = input_shape; | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
@@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
//} | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
if (!cell.Built) | |||
{ | |||
@@ -1,5 +1,6 @@ | |||
using System.Data; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Operations.Activation; | |||
using static HDF.PInvoke.H5Z; | |||
using static Tensorflow.ApiDef.Types; | |||
@@ -14,12 +15,13 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var input_dim = input_shape[-1]; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var input_dim = single_shape[-1]; | |||
_buildInputShape = input_shape; | |||
kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||
kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||
initializer: args.KernelInitializer | |||
//regularizer = self.kernel_regularizer, | |||
//constraint = self.kernel_constraint, | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
@@ -18,11 +19,12 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
this.args = args; | |||
} | |||
public override void build(Shape input_shape) | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
var input_dim = input_shape[-1]; | |||
var single_shape = input_shape.ToSingleShape(); | |||
var input_dim = single_shape[-1]; | |||
kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||
kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||
initializer: args.KernelInitializer | |||
); | |||
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Models | |||
{ | |||
public class ModelsApi: IModelsApi | |||
{ | |||
public Functional from_config(ModelConfig config) | |||
public Functional from_config(FunctionalConfig config) | |||
=> Functional.from_config(config); | |||
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||
@@ -22,16 +22,19 @@ namespace Tensorflow.Keras.Saving | |||
public int SharedObjectId { get; set; } | |||
[JsonProperty("must_restore_from_config")] | |||
public bool MustRestoreFromConfig { get; set; } | |||
[JsonProperty("config")] | |||
public JObject Config { get; set; } | |||
[JsonProperty("build_input_shape")] | |||
public TensorShapeConfig BuildInputShape { get; set; } | |||
public KerasShapesWrapper BuildInputShape { get; set; } | |||
[JsonProperty("batch_input_shape")] | |||
public TensorShapeConfig BatchInputShape { get; set; } | |||
public KerasShapesWrapper BatchInputShape { get; set; } | |||
[JsonProperty("activity_regularizer")] | |||
public IRegularizer ActivityRegularizer { get; set; } | |||
[JsonProperty("input_spec")] | |||
public JToken InputSpec { get; set; } | |||
[JsonProperty("stateful")] | |||
public bool? Stateful { get; set; } | |||
[JsonProperty("model_config")] | |||
public KerasModelConfig? ModelConfig { get; set; } | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
using Newtonsoft.Json; | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
public class KerasModelConfig | |||
{ | |||
[JsonProperty("class_name")] | |||
public string ClassName { get; set; } | |||
[JsonProperty("config")] | |||
public JObject Config { get; set; } | |||
} | |||
} |
@@ -8,6 +8,7 @@ using System.Diagnostics; | |||
using System.Linq; | |||
using System.Reflection; | |||
using System.Text.RegularExpressions; | |||
using Tensorflow.Extensions; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
@@ -356,7 +357,7 @@ namespace Tensorflow.Keras.Saving | |||
var (obj, setter) = _revive_from_config(identifier, metadata, node_id); | |||
if (obj is null) | |||
{ | |||
(obj, setter) = _revive_custom_object(identifier, metadata); | |||
(obj, setter) = revive_custom_object(identifier, metadata); | |||
} | |||
if(obj is null) | |||
{ | |||
@@ -398,7 +399,7 @@ namespace Tensorflow.Keras.Saving | |||
return (obj, setter); | |||
} | |||
private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||
private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata) | |||
{ | |||
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
{ | |||
@@ -437,7 +438,7 @@ namespace Tensorflow.Keras.Saving | |||
} | |||
else | |||
{ | |||
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>()); | |||
model = new Functional(new Tensors(), new Tensors(), config.TryGetOrReturnNull<string>("name")); | |||
} | |||
// Record this model and its layers. This will later be used to reconstruct | |||
@@ -619,7 +620,7 @@ namespace Tensorflow.Keras.Saving | |||
} | |||
} | |||
private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||
private bool _try_build_layer(Layer obj, int node_id, KerasShapesWrapper build_input_shape) | |||
{ | |||
if (obj.Built) | |||
return true; | |||
@@ -679,10 +680,10 @@ namespace Tensorflow.Keras.Saving | |||
return inputs; | |||
} | |||
private Shape _infer_input_shapes(int layer_node_id) | |||
private KerasShapesWrapper _infer_input_shapes(int layer_node_id) | |||
{ | |||
var inputs = _infer_inputs(layer_node_id); | |||
return nest.map_structure(x => x.shape, inputs); | |||
return new KerasShapesWrapper(nest.map_structure(x => x.shape, inputs)); | |||
} | |||
private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||
@@ -173,6 +173,11 @@ namespace Tensorflow.Keras.Utils | |||
obj is not Type; | |||
} | |||
public static Tensor generate_placeholders_from_shape(Shape shape) | |||
{ | |||
return array_ops.placeholder(keras.backend.floatx(), shape); | |||
} | |||
// recusive | |||
static bool uses_keras_history(Tensor op_input) | |||
{ | |||
@@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Utils | |||
return args as LayerArgs; | |||
} | |||
public static ModelConfig deserialize_model_config(JToken json) | |||
public static FunctionalConfig deserialize_model_config(JToken json) | |||
{ | |||
ModelConfig config = new ModelConfig(); | |||
FunctionalConfig config = new FunctionalConfig(); | |||
config.Name = json["name"].ToObject<string>(); | |||
config.Layers = new List<LayerConfig>(); | |||
var layersToken = json["layers"]; | |||
@@ -18,8 +18,8 @@ namespace TensorFlowNET.Keras.UnitTest | |||
{ | |||
var model = GetFunctionalModel(); | |||
var config = model.get_config(); | |||
Debug.Assert(config is ModelConfig); | |||
var new_model = new ModelsApi().from_config(config as ModelConfig); | |||
Debug.Assert(config is FunctionalConfig); | |||
var new_model = new ModelsApi().from_config(config as FunctionalConfig); | |||
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); | |||
} | |||