Browse Source

Merge pull request #1027 from AsakusaRinne/fix_1013

Fix the error when loading VGG19.
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
6340054944
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 26 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
  2. +13
    -21
      src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  4. +9
    -2
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  5. +6
    -0
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  6. +24
    -1
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// This class has nothing but the attributes different from `LayerArgs`.
/// It's used to serialize the model to `tf` format.
/// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`,
/// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`.
/// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`.
/// </summary>
public class AutoSerializeLayerArgs: LayerArgs
{


+ 13
- 21
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -7,6 +7,11 @@ using System.Text;

namespace Tensorflow.Keras.Common
{
class ShapeInfoFromPython
{
public string class_name { get; set; }
public long?[] items { get; set; }
}
public class CustomizedShapeJsonConverter: JsonConverter
{
public override bool CanConvert(Type objectType)
@@ -44,36 +49,23 @@ namespace Tensorflow.Keras.Common
dims[i] = shape.dims[i];
}
}
var token = JToken.FromObject(dims);
var token = JToken.FromObject(new ShapeInfoFromPython()
{
class_name = "__tuple__",
items = dims
});
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
long?[] dims;
try
{
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
catch (JsonSerializationException ex)
{
if (reader.Value.Equals("class_name"))
{
reader.Read();
reader.Read();
reader.Read();
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
else
{
throw ex;
}
}
if (dims is null)
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
if (shape_info_from_python is null)
{
return null;
}
long ?[]dims = shape_info_from_python.items;
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
{


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -108,7 +108,7 @@ https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />


+ 9
- 2
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -563,7 +563,7 @@ namespace Tensorflow
return proto.KindCase switch
{
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, dependencies),
SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies),
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
@@ -626,7 +626,7 @@ namespace Tensorflow
}

private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto,
Dictionary<OneOf<string, int>, Trackable> dependencies)
IDictionary<OneOf<string, int>, Trackable> dependencies)
{
var fn = function_deserialization.recreate_function(proto, _concrete_functions);
foreach (var name in proto.ConcreteFunctions)
@@ -644,6 +644,13 @@ namespace Tensorflow
return (fn, setattr);
}

private (Tensor, Action<object, object, object>) _get_tensor_from_fn(CapturedTensor proto)
{
var outer_graph = _concrete_functions[proto.ConcreteFunction].func_graph;
var captured_tensor = outer_graph.get_tensor_by_name(proto.Name);
return (captured_tensor, setattr);
}

// TODO: remove this to a common class.
public static Action<object, object, object> setattr = (x, y, z) =>
{


+ 6
- 0
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -71,6 +71,9 @@ namespace Tensorflow.Keras.Utils
var args = deserializationGenericMethod.Invoke(config, null);
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);

// TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

return layer as Layer;
}

@@ -82,6 +85,9 @@ namespace Tensorflow.Keras.Utils
return null;
}
Debug.Assert(layer is Layer);

// TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

return layer as Layer;
}



+ 24
- 1
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -6,13 +6,13 @@ using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelLoad
{
[Ignore]
[TestMethod]
public void SimpleModelFromAutoCompile()
{
@@ -80,4 +80,27 @@ public class SequentialModelLoad

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}

[Ignore]
[TestMethod]
public void VGG19()
{
var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
model.summary();

var classify_model = keras.Sequential(new System.Collections.Generic.List<Tensorflow.Keras.ILayer>()
{
model,
keras.layers.Flatten(),
keras.layers.Dense(10),
});
classify_model.summary();

classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var x = np.random.uniform(0, 1, (8, 512, 512, 3));
var y = np.ones((8));

classify_model.fit(x, y, batch_size: 4);
}
}

Loading…
Cancel
Save