Fix the error when loading VGG19.tags/v0.100.5-BERT-load
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
/// This class has nothing but the attributes different from `LayerArgs`. | |||
/// It's used to serialize the model to `tf` format. | |||
/// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, | |||
/// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`. | |||
/// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. | |||
/// </summary> | |||
public class AutoSerializeLayerArgs: LayerArgs | |||
{ | |||
@@ -7,6 +7,11 @@ using System.Text; | |||
namespace Tensorflow.Keras.Common | |||
{ | |||
class ShapeInfoFromPython | |||
{ | |||
public string class_name { get; set; } | |||
public long?[] items { get; set; } | |||
} | |||
public class CustomizedShapeJsonConverter: JsonConverter | |||
{ | |||
public override bool CanConvert(Type objectType) | |||
@@ -44,36 +49,23 @@ namespace Tensorflow.Keras.Common | |||
dims[i] = shape.dims[i]; | |||
} | |||
} | |||
var token = JToken.FromObject(dims); | |||
var token = JToken.FromObject(new ShapeInfoFromPython() | |||
{ | |||
class_name = "__tuple__", | |||
items = dims | |||
}); | |||
token.WriteTo(writer); | |||
} | |||
} | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
long?[] dims; | |||
try | |||
{ | |||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
} | |||
catch (JsonSerializationException ex) | |||
{ | |||
if (reader.Value.Equals("class_name")) | |||
{ | |||
reader.Read(); | |||
reader.Read(); | |||
reader.Read(); | |||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
} | |||
else | |||
{ | |||
throw ex; | |||
} | |||
} | |||
if (dims is null) | |||
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader); | |||
if (shape_info_from_python is null) | |||
{ | |||
return null; | |||
} | |||
long ?[]dims = shape_info_from_python.items; | |||
long[] convertedDims = new long[dims.Length]; | |||
for(int i = 0; i < dims.Length; i++) | |||
{ | |||
@@ -108,7 +108,7 @@ https://tensorflownet.readthedocs.io</Description> | |||
<ItemGroup> | |||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" /> | |||
<PackageReference Include="OneOf" Version="3.0.223" /> | |||
<PackageReference Include="Protobuf.Text" Version="0.7.0" /> | |||
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | |||
@@ -563,7 +563,7 @@ namespace Tensorflow | |||
return proto.KindCase switch | |||
{ | |||
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), | |||
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | |||
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, dependencies), | |||
SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies), | |||
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | |||
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | |||
@@ -626,7 +626,7 @@ namespace Tensorflow | |||
} | |||
private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||
Dictionary<OneOf<string, int>, Trackable> dependencies) | |||
IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
{ | |||
var fn = function_deserialization.recreate_function(proto, _concrete_functions); | |||
foreach (var name in proto.ConcreteFunctions) | |||
@@ -644,6 +644,13 @@ namespace Tensorflow | |||
return (fn, setattr); | |||
} | |||
private (Tensor, Action<object, object, object>) _get_tensor_from_fn(CapturedTensor proto) | |||
{ | |||
var outer_graph = _concrete_functions[proto.ConcreteFunction].func_graph; | |||
var captured_tensor = outer_graph.get_tensor_by_name(proto.Name); | |||
return (captured_tensor, setattr); | |||
} | |||
// TODO: remove this to a common class. | |||
public static Action<object, object, object> setattr = (x, y, z) => | |||
{ | |||
@@ -71,6 +71,9 @@ namespace Tensorflow.Keras.Utils | |||
var args = deserializationGenericMethod.Invoke(config, null); | |||
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | |||
Debug.Assert(layer is Layer); | |||
// TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj) | |||
return layer as Layer; | |||
} | |||
@@ -82,6 +85,9 @@ namespace Tensorflow.Keras.Utils | |||
return null; | |||
} | |||
Debug.Assert(layer is Layer); | |||
// TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj) | |||
return layer as Layer; | |||
} | |||
@@ -6,13 +6,13 @@ using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.UnitTest.Helpers; | |||
using Tensorflow.NumPy; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
[TestClass] | |||
public class SequentialModelLoad | |||
{ | |||
[Ignore] | |||
[TestMethod] | |||
public void SimpleModelFromAutoCompile() | |||
{ | |||
@@ -80,4 +80,27 @@ public class SequentialModelLoad | |||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
} | |||
[Ignore] | |||
[TestMethod] | |||
public void VGG19() | |||
{ | |||
var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19"); | |||
model.summary(); | |||
var classify_model = keras.Sequential(new System.Collections.Generic.List<Tensorflow.Keras.ILayer>() | |||
{ | |||
model, | |||
keras.layers.Flatten(), | |||
keras.layers.Dense(10), | |||
}); | |||
classify_model.summary(); | |||
classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
var x = np.random.uniform(0, 1, (8, 512, 512, 3)); | |||
var y = np.ones((8)); | |||
classify_model.fit(x, y, batch_size: 4); | |||
} | |||
} |