Browse Source

Change Config as JObject in KerasMetaData

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
9a4c518401
7 changed files with 51 additions and 8 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs
  3. +16
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs
  4. +12
    -3
      src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs
  5. +2
    -1
      src/TensorFlowNET.Keras/Saving/KerasMetaData.cs
  6. +17
    -2
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  7. +1
    -1
      src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs

+ 2
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -45,5 +45,7 @@
public IRegularizer ActivityRegularizer { get; set; }

public bool Autocast { get; set; }

public bool IsFromConfig { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs View File

@@ -1,6 +1,6 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class ResizingArgs : LayerArgs
public class ResizingArgs : PreprocessingLayerArgs
{
public int Height { get; set; }
public int Width { get; set; }


+ 16
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class PreprocessingLayer : Layer
{
public PreprocessingLayer(PreprocessingLayerArgs args) : base(args)
{

}
}
}

+ 12
- 3
src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs View File

@@ -1,7 +1,9 @@
using System;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -9,7 +11,7 @@ namespace Tensorflow.Keras.Layers
/// Resize the batched image input to target height and width.
/// The input should be a 4-D tensor in the format of NHWC.
/// </summary>
public class Resizing : Layer
public class Resizing : PreprocessingLayer
{
ResizingArgs args;
public Resizing(ResizingArgs args) : base(args)
@@ -26,5 +28,12 @@ namespace Tensorflow.Keras.Layers
{
return new TensorShape(input_shape.dims[0], args.Height, args.Width, input_shape.dims[3]);
}

public static Resizing from_config(JObject config)
{
var args = JsonConvert.DeserializeObject<ResizingArgs>(config.ToString());
args.IsFromConfig = true;
return new Resizing(args);
}
}
}

+ 2
- 1
src/TensorFlowNET.Keras/Saving/KerasMetaData.cs View File

@@ -1,4 +1,5 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;
@@ -16,7 +17,7 @@ namespace Tensorflow.Keras.Saving
public int SharedObjectId { get; set; }
[JsonProperty("must_restore_from_config")]
public bool MustRestoreFromConfig { get; set; }
public ModelConfig Config { get; set; }
public JObject Config { get; set; }
[JsonProperty("build_input_shape")]
public TensorShapeConfig BuildInputShape { get; set; }
}


+ 17
- 2
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -5,8 +5,10 @@ using System.Linq;
using System.Text.RegularExpressions;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Saving
{
@@ -73,7 +75,7 @@ namespace Tensorflow.Keras.Saving
{
model = new Sequential(new SequentialArgs
{
Name = config.Name
Name = config.GetValue("name").ToString()
});
}
else if (class_name == "Functional")
@@ -97,7 +99,12 @@ namespace Tensorflow.Keras.Saving
var class_name = metadata.ClassName;
var shared_object_id = metadata.SharedObjectId;
var must_restore_from_config = metadata.MustRestoreFromConfig;

var obj = class_name switch
{
"Resizing" => Resizing.from_config(config),
_ => throw new NotImplementedException("")
};
var built = _try_build_layer(obj, node_id, metadata.BuildInputShape);
return null;
}

@@ -157,5 +164,13 @@ namespace Tensorflow.Keras.Saving

return false;
}

bool _try_build_layer(Layer obj, int node_id, TensorShape build_input_shape)
{
if (obj.Built)
return true;

return false;
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs View File

@@ -10,6 +10,6 @@ namespace Tensorflow.Keras.Saving
public int?[] Items { get; set; }

public static implicit operator TensorShape(TensorShapeConfig shape)
=> new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());
=> shape == null ? null : new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());
}
}

Loading…
Cancel
Save