Browse Source

Support loading of SavedModel format (#989)

* Add CheckpointReader and corresponding C APIs.

* Add essential components of SavedModel format loading.

* Add checkpoint reading for SavedModel format loading.

* Revise customized json converters.

* Add support for loading models from python.

* Fix the duplicated weights in Keras.Model.

* Add alexnet loading test and check for loaded weights.

* Fix ci error caused by branch merge.

* Resolve the comments and errors.

* Fix training getting stuck when loading a model.

* Fix training getting stuck when loading a model.

* fix intptr.

---------

Co-authored-by: Haiping Chen <haiping008@gmail.com>
tags/v0.100.4-load-saved-model
Yaohui Liu GitHub 2 years ago
parent
commit
52b513d750
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
70 changed files with 3118 additions and 209 deletions
  1. +18
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  2. +100
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  4. +27
    -0
      src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs
  5. +378
    -0
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  7. +331
    -0
      src/TensorFlowNET.Core/Checkpoint/restore.cs
  8. +5
    -1
      src/TensorFlowNET.Core/Eager/execute.cs
  9. +13
    -2
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  10. +12
    -0
      src/TensorFlowNET.Core/IO/gfile.cs
  11. +10
    -1
      src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
  12. +36
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
  13. +32
    -5
      src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
  14. +21
    -3
      src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
  15. +1
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  16. +3
    -0
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  17. +1
    -0
      src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
  18. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  19. +42
    -1
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  20. +1
    -0
      src/TensorFlowNET.Core/Operations/io_ops.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  22. +5
    -1
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  23. +3
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  24. +18
    -0
      src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
  25. +23
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs
  26. +8
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
  27. +6
    -6
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs
  28. +0
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  29. +22
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs
  30. +36
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  31. +641
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  32. +122
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs
  33. +0
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  34. +0
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs
  35. +79
    -15
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  36. +43
    -10
      src/TensorFlowNET.Core/Training/Trackable.cs
  37. +4
    -3
      src/TensorFlowNET.Core/Training/TrackableUtils.cs
  38. +9
    -5
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  39. +18
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  40. +3
    -8
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  41. +13
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  42. +26
    -0
      src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
  43. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  44. +70
    -42
      src/TensorFlowNET.Keras/Engine/Layer.cs
  45. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Save.cs
  46. +38
    -7
      src/TensorFlowNET.Keras/Engine/Model.cs
  47. +10
    -5
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  48. +1
    -2
      src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
  49. +1
    -2
      src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
  50. +1
    -2
      src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
  51. +0
    -5
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  52. +0
    -5
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  53. +1
    -1
      src/TensorFlowNET.Keras/Metrics/Metric.cs
  54. +3
    -13
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  55. +527
    -34
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  56. +3
    -3
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  57. +96
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
  58. +69
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs
  59. +68
    -0
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  60. +1
    -1
      src/TensorFlowNET.Keras/Utils/layer_utils.cs
  61. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy
  62. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb
  63. +9
    -0
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb
  64. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy
  65. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb
  66. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001
  67. BIN
      test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index
  68. +68
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
  69. +10
    -14
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
  70. +24
    -0
      test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

+ 18
- 0
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -149,4 +149,22 @@ public static class CheckPointUtils
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// } // }
} }

/// <summary>
/// Traverse the object graph and list all accessible objects.
/// </summary>
/// <param name="graph_view">The graph view whose root is the traversal starting point.</param>
/// <returns>All trackables reachable from the root of <paramref name="graph_view"/>.</returns>
public static IList<Trackable> list_objects(ObjectGraphView graph_view)
{
    return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
}

/// <summary>
/// Filters out trackables that contribute no saveable attributes to the checkpoint.
/// </summary>
/// <param name="full_list">Candidate trackables, typically from <see cref="list_objects"/>.</param>
/// <returns>Only the trackables that expose at least one saveable.</returns>
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
    // Bug fix: `TakeWhile` stopped at the FIRST trackable without saveables, silently
    // dropping every later trackable that did have attributes. `Where` filters the
    // whole sequence, matching the Python `_objects_with_attributes` semantics.
    return full_list.Where(x =>
    {
        var saveables = x.gather_saveables_for_checkpoint();
        return saveables is not null && saveables.Count > 0;
    });
}
} }

+ 100
- 0
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -0,0 +1,100 @@
using Tensorflow.Util;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// Safe-handle wrapper for the native `TF_CheckpointReader*` pointer; guarantees the
/// native reader is released exactly once, even if the managed wrapper is finalized.
/// </summary>
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
{
    public SafeCheckpointReaderHandle(): base()
    {
    }
    public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
    {

    }

    // Invoked by the runtime when the handle is no longer referenced.
    protected override bool ReleaseHandle()
    {
        c_api.TF_DeleteCheckpointReader(handle);
        SetHandle(IntPtr.Zero);
        return true;
    }
}
/// <summary>
/// Managed reader over a TensorFlow checkpoint: exposes the stored variables, their
/// dtypes and shapes, and materializes individual variables as tensors.
/// </summary>
public class CheckpointReader
{
    private SafeCheckpointReaderHandle _handle;
    // Both maps are populated eagerly by the constructor via ReadAllShapeAndType.
    public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
    public Dictionary<string, Shape> VariableToShapeMap { get; set; }

    /// <summary>
    /// Opens the checkpoint at <paramref name="filename"/> and caches every
    /// variable's dtype and shape.
    /// </summary>
    public CheckpointReader(string filename)
    {
        var status = new Status();
        _handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
        status.Check(true);
        ReadAllShapeAndType();
    }

    /// <summary>
    /// Non-zero when the checkpoint contains a tensor named <paramref name="name"/>.
    /// </summary>
    public int HasTensor(string name)
        => c_api.TF_CheckpointReaderHasTensor(_handle, name);

    /// <summary>
    /// Get the variable name at <paramref name="index"/>.
    /// </summary>
    public string GetVariable(int index)
        => c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));

    /// <summary>
    /// Number of variables stored in the checkpoint.
    /// </summary>
    public int Size()
        => c_api.TF_CheckpointReaderSize(_handle);

    public TF_DataType GetVariableDataType(string name)
        => c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);

    public Shape GetVariableShape(string name)
    {
        var rank = GetVariableNumDims(name);
        var dims = new long[rank];
        var status = new Status();
        c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, rank, status.Handle);
        status.Check(true);
        return new Shape(dims);
    }

    public int GetVariableNumDims(string name)
        => c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);

    // NOTE(review): `dtype` is currently unused by the native call; it is kept only
    // for signature compatibility with existing callers.
    public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
    {
        var status = new Status();
        var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
        status.Check(true);
        return new Tensor(tensor);
    }

    // Walks every variable once and fills both lookup maps.
    private void ReadAllShapeAndType()
    {
        VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
        VariableToShapeMap = new Dictionary<string, Shape>();
        var count = Size();
        for (var i = 0; i < count; i++)
        {
            var name = GetVariable(i);
            var shape = GetVariableShape(name);
            var dtype = GetVariableDataType(name);
            VariableToDataTypeMap[name] = dtype;
            VariableToShapeMap[name] = shape;
        }
    }
}
}

+ 3
- 3
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -175,9 +175,9 @@ public static class SaveUtilV1
{ {
var name = factory_data.name; var name = factory_data.name;
var key = factory_data.checkpoint_key; var key = factory_data.checkpoint_key;
var maybe_saveable = factory_data.factory;
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory);


// TODO: oneflow python has a process with callable `saveable_factory`.
// TODO: tensorflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new(); List<MySaveableObject> saveables = new();
if (maybe_saveable.TryGet<MySaveableObject>(out var s)) if (maybe_saveable.TryGet<MySaveableObject>(out var s))
{ {
@@ -217,7 +217,7 @@ public static class SaveUtilV1


public record class CheckpointFactoryData public record class CheckpointFactoryData
( (
Maybe<BaseResourceVariable, MySaveableObject> factory,
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
string name, string name,
string checkpoint_key string checkpoint_key
); );

+ 27
- 0
src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs View File

@@ -0,0 +1,27 @@
using System.Runtime.InteropServices;
using Tensorflow.Checkpoint;

namespace Tensorflow
{
public unsafe partial class c_api
{
    // Opens a checkpoint for reading; the returned safe handle owns the native reader.
    [DllImport(TensorFlowLibName)]
    internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);
    // Releases a reader created by TF_NewCheckpointReader (called from the safe handle).
    [DllImport(TensorFlowLibName)]
    internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
    // Non-zero when the checkpoint contains a tensor with the given name.
    [DllImport(TensorFlowLibName)]
    internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);
    // Returns a pointer to the name of the variable at `index`; decode via StringPiece.
    [DllImport(TensorFlowLibName)]
    internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);
    // Number of variables stored in the checkpoint.
    [DllImport(TensorFlowLibName)]
    internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);
    [DllImport(TensorFlowLibName)]
    internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);
    // Fills `dims` (length `num_dims`) with the named variable's shape.
    [DllImport(TensorFlowLibName)]
    internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
    [DllImport(TensorFlowLibName)]
    internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);
    // Materializes the named variable as a tensor; caller owns the returned handle.
    [DllImport(TensorFlowLibName)]
    internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
}
}

+ 378
- 0
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -6,8 +6,12 @@ using System.Linq;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Exceptions;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Operations;
using Newtonsoft.Json;
using Tensorflow.Training;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


@@ -21,8 +25,20 @@ public class TrackableSaver
private TrackableObjectGraph _last_save_object_graph; private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null; private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null; private Tensor? _file_prefix_feed_tensor = null;
private Tensor? _file_prefix_placeholder = null;
private Dictionary<Trackable, Trackable>? _object_map = null; private Dictionary<Trackable, Trackable>? _object_map = null;
private object? _cache = null; private object? _cache = null;
public Tensor? FilePrefixPlaceHolder
{
get
{
return _file_prefix_placeholder;
}
set
{
_file_prefix_placeholder = value;
}
}
public TrackableSaver(ObjectGraphView graph_view) public TrackableSaver(ObjectGraphView graph_view)
{ {
_graph_view = graph_view; _graph_view = graph_view;
@@ -192,4 +208,366 @@ public class TrackableSaver
return save_path; return save_path;
} }
} }

/// <summary>
/// Restores the trackables in this saver's graph view from the checkpoint at
/// <paramref name="save_path"/>.
/// </summary>
/// <param name="save_path">Checkpoint path; when null nothing is restored.</param>
/// <param name="options">Checkpoint options; a default instance is used when null.</param>
/// <returns>
/// An <see cref="InitializationOnlyStatus"/> when no path is given, otherwise a
/// <see cref="CheckpointLoadStatus"/> wrapping the restore coordinator.
/// </returns>
public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
{
    if (options is null)
    {
        options = new CheckpointOptions();
    }
    if (save_path is null)
    {
        return new InitializationOnlyStatus(_graph_view, ops.uid());
    }

    CheckpointReader reader = new CheckpointReader(save_path);
    // Bug fix: we are building a graph exactly when we are NOT executing eagerly.
    // The original assigned `tf.Context.executing_eagerly()` directly, which inverted
    // every branch below (e.g. using the feed-dict placeholder path in eager mode).
    // Matches `graph_building = not context.executing_eagerly()` in TF Python.
    bool graph_building = !tf.Context.executing_eagerly();
    Dictionary<string, TF_DataType> dtype_map = null;
    if (!graph_building)
    {
        dtype_map = reader.VariableToDataTypeMap;
    }
    Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);

    Dictionary<Tensor, string> file_prefix_feed_dict;
    Tensor file_prefix_tensor;
    if (graph_building)
    {
        // In graph mode the checkpoint path is fed at run time through a reusable tensor.
        if (_file_prefix_placeholder is null)
        {
            tf.device("/cpu:0");
            _file_prefix_placeholder = constant_op.constant("model");
        }
        file_prefix_tensor = _file_prefix_placeholder;
        file_prefix_feed_dict = new();
        file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
    }
    else
    {
        tf.device("/cpu:0");
        file_prefix_tensor = constant_op.constant(save_path);
        file_prefix_feed_dict = null;
    }
    // The serialized object graph may come back as a scalar string or a 1-D tensor.
    TrackableObjectGraph object_graph_proto = new();
    if (object_graph_string.ndim > 0)
    {
        object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
    }
    else
    {
        object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]);
    }
    CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
        object_graph_proto: object_graph_proto,
        save_path: save_path,
        save_path_tensor: file_prefix_tensor,
        reader: reader,
        restore_op_cache: null,
        graph_view: _graph_view,
        options: options,
        saveables_cache: null
    );

    // Node 0 of the proto is always the root object.
    new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root);

    // Attached dependencies may not be reachable from the root, so restore them too.
    if (_graph_view.AttachedDependencies is not null)
    {
        foreach (var refer in _graph_view.AttachedDependencies)
        {
            if (refer.Name == "root")
            {
                continue;
            }
            int? proto_id = null;
            // Find proto ID of attached dependency (if it is in the proto).
            foreach (var proto_refer in object_graph_proto.Nodes[0].Children)
            {
                if (proto_refer.LocalName == refer.Name)
                {
                    proto_id = proto_refer.NodeId;
                    break;
                }
            }

            if (proto_id is null)
            {
                continue;
            }

            // Object has already been restored. This can happen when there's an
            // indirect connection from the attached object to the root.
            if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value))
            {
                continue;
            }

            new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer);
        }
    }

    return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view);
}
}

/// <summary>
/// State shared by every <see cref="CheckpointPosition"/> during a single restore:
/// the object-graph proto, the reader, bookkeeping of which proto nodes matched
/// which objects, and the restore ops created so far.
/// </summary>
public class CheckpointRestoreCoordinator
{
    private CheckpointOptions _options;
    private TrackableObjectGraph _object_graph_proto;
    private int _restore_uid;
    private HashSet<int> _matched_proto_ids;
    private Tensor _save_path_tensor;
    private string _save_path_string;
    private CheckpointReader _reader;
    private Dictionary<string, TF_DataType> _dtype_map;
    private Dictionary<string, Shape> _shape_map;
    private ObjectGraphView _graph_view;
    private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations;
    private bool _expect_partial_attr;
    private List<Operation> _restore_ops;
    private List<Trackable> _all_trackables;
    private Dictionary<int, Trackable> _object_by_proto_id;
    private Dictionary<string, Operation> _restore_ops_by_name;
    private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations;
    private Dictionary<int, IList<string>> _unused_attributes;

    public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
        CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
    {
        // TODO(Rinne): cache.
        _options = options;
        _object_graph_proto = object_graph_proto;
        _restore_uid = ops.uid();
        _save_path_tensor = save_path_tensor;
        _save_path_string = save_path;
        _reader = reader;
        if (_reader is null)
        {
            _reader = new CheckpointReader(save_path);
        }
        _dtype_map = _reader.VariableToDataTypeMap;
        _shape_map = _reader.VariableToShapeMap;
        _graph_view = graph_view;
        _restore_ops = new List<Operation>();
        _restore_ops_by_name = new Dictionary<string, Operation>();
        _all_trackables = new List<Trackable>();
        _matched_proto_ids = new HashSet<int>();
        _object_by_proto_id = new Dictionary<int, Trackable>();
        _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>();
        _deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>();
        // Bug fix: `_unused_attributes` was never initialized, so the first access to
        // `UnusedAttributes` (e.g. `UnusedAttributes.SetDefault(...)` from
        // CheckpointPosition) threw NullReferenceException.
        _unused_attributes = new Dictionary<int, IList<string>>();

        _expect_partial_attr = false;
        // Index every slot-variable reference by the node id of its original variable
        // so optimizer slots can be restored when that variable is reached.
        for (int i = 0; i < _object_graph_proto.Nodes.Count; i++)
        {
            var node = _object_graph_proto.Nodes[i];
            foreach (var slot_reference in node.SlotVariables)
            {
                _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>())
                    .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
            }
        }

        // skip the deleter and cache.
    }

    public bool ExpectPartial
    {
        get
        {
            return _expect_partial_attr;
        }
        set
        {
            _expect_partial_attr = value;
        }
    }

    /// <summary>
    /// Corresponding to `all_python_objects` of tensorflow python
    /// </summary>
    public List<Trackable> AllTrackables => _all_trackables;
    public HashSet<int> MatchedProtoIds => _matched_proto_ids;
    public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
    public int RestoreUid => _restore_uid;
    public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
    public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations;
    public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations;
    public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name;
    public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes;

    /// <summary>
    /// Appends freshly created restore ops (restoration callbacks are not ported yet).
    /// </summary>
    public void new_restore_ops(IEnumerable<Operation> new_ops)
    {
        _restore_ops.AddRange(new_ops);
        // skip the callback.
    }

    /// <summary>
    /// Builds and collects the restore ops for the gathered saveables.
    /// </summary>
    /// <param name="tensor_saveables">Saveables keyed by checkpoint key.</param>
    /// <param name="positions">Positions requiring a full-object restore (not yet supported).</param>
    /// <param name="registered_savers">Unused; registered savers are not ported yet.</param>
    /// <returns>The restore ops (empty under eager execution).</returns>
    public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
    {
        List<Operation> restore_ops = new();
        // Full-object restore positions are not supported yet; any entry is fatal.
        foreach (var position in positions)
        {
            var key = position.ObjectProto.Attributes[0].CheckpointKey;
            throw new NotImplementedException();
        }

        Dictionary<string, BaseResourceVariable> variable_dict = new();
        foreach (var item in tensor_saveables)
        {
            if (item.Value.TryGet<BaseResourceVariable>(out var variable))
            {
                variable_dict[item.Key] = variable;
            }
            else
            {
                // Only resource-variable saveables are supported on this path.
                throw new TypeError();
            }
        }

        if (tensor_saveables is not null && tensor_saveables.Count > 0)
        {
            var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict);
            var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options);
            if (!tf.Context.executing_eagerly())
            {
                foreach (var item in new_restore_ops)
                {
                    restore_ops.Add(item.Value);
                    Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key));
                    _restore_ops_by_name[item.Key] = item.Value;
                }
            }
        }
        return restore_ops;
    }
}

/// <summary>
/// Abstract result of a checkpoint restore attempt. Concrete subclasses report how
/// much of the checkpoint matched and can run pending restore/initialization ops.
/// </summary>
public abstract class LoadStatus
{
    /// <summary>Asserts the whole checkpoint was consumed by some object.</summary>
    public abstract LoadStatus assert_consumed();
    /// <summary>Asserts every existing object matched a checkpointed value.</summary>
    public abstract LoadStatus assert_existing_objects_matched();
    /// <summary>Asserts at least something non-trivial was restored.</summary>
    public abstract LoadStatus assert_nontrivial_match();
    /// <summary>Runs the pending restore ops (graph mode only).</summary>
    public abstract LoadStatus run_restore_ops(Session? session = null);
    /// <summary>Runs restore ops or variable initializers, whichever applies.</summary>
    public abstract void initialize_or_restore(Session? session = null);
    /// <summary>Marks a partially-consumed checkpoint as expected; default is a no-op.</summary>
    public virtual LoadStatus expect_partial()
    {
        return this;
    }
}

/// <summary>
/// Load status returned when no checkpoint path was supplied: nothing is restored,
/// so every assertion about restored values fails by construction.
/// </summary>
public class InitializationOnlyStatus : LoadStatus
{
    // Shared message for the three "nothing was restored" assertions.
    private const string NothingRestoredMessage =
        "No checkpoint specified (save_path=None); nothing is being restored.";

    private int _restore_uid;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
    {
        _restore_uid = restore_uid;
        _object_graph_view = object_graph_view;
        _root = object_graph_view.Root;
    }

    public override LoadStatus assert_consumed()
        => throw new AssertionError(NothingRestoredMessage);

    public override LoadStatus assert_existing_objects_matched()
        => throw new AssertionError(NothingRestoredMessage);

    public override LoadStatus assert_nontrivial_match()
        => throw new AssertionError(NothingRestoredMessage);

    public override LoadStatus run_restore_ops(Session? session = null)
        => throw new AssertionError("No checkpoint specified, so no restore ops are available "
            + "(save_path=None to Saver.restore).");

    /// <summary>
    /// Would run the variable initializers in graph mode; a no-op under eager execution.
    /// </summary>
    public override void initialize_or_restore(Session? session = null)
    {
        if (tf.Context.executing_eagerly())
        {
            return;
        }
        if (session is null)
        {
            session = new Session();
        }
        var trackable_objects = CheckPointUtils.list_objects(_object_graph_view);
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}

/// <summary>
/// Load status for a real checkpoint restore; wraps the restore coordinator so
/// callers can assert how completely the checkpoint matched the object graph.
/// </summary>
internal class CheckpointLoadStatus : LoadStatus
{
    private CheckpointRestoreCoordinator _checkpoint;
    private Dictionary<Tensor, string> _feed_dict;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view) : base()
    {
        _checkpoint = checkpoint;
        _feed_dict = feed_dict;
        _object_graph_view = graph_view;
        _root = graph_view.Root;
    }

    public CheckpointRestoreCoordinator Checkpoint => _checkpoint;

    public override LoadStatus assert_consumed()
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Asserts that every object currently in the program was matched by a
    /// checkpointed value.
    /// </summary>
    /// <exception cref="AssertionError">Some object was not assigned a restored value.</exception>
    public override LoadStatus assert_existing_objects_matched()
    {
        for (int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++)
        {
            var node = _checkpoint.ObjectGraphProto.Nodes[i];
            // An object bound to node i but never updated by this restore is a mismatch.
            if (_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) &&
                trackable.UpdateUid < _checkpoint.RestoreUid)
            {
                throw new AssertionError($"Object {node} not assigned a value from checkpoint.");
            }
        }
        foreach (var trackable_object in CheckPointUtils.list_objects(_object_graph_view))
        {
            // Empty data structures are mere containers; they carry no values to match.
            if (trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0)
            {
                continue;
            }
            _checkpoint.AllTrackables.Add(trackable_object);
        }
        // Materialize once to avoid re-running the query for Any/Count/enumeration.
        var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables)
            .Except(_checkpoint.ObjectByProtoId.Values)
            .ToList();
        if (unused_trackables.Count > 0)
        {
            var num_unused_trackables = unused_trackables.Count;
            var num_variables_to_show = Math.Min(10, num_unused_trackables);
            // Bug fix: the message previously embedded the literal text
            // "{list(unused_python_objects)[:num_variables_to_show]}" copied from the
            // Python source; actually list some of the unmatched objects instead.
            var examples = string.Join(", ", unused_trackables.Take(num_variables_to_show));
            throw new AssertionError($"Found {num_unused_trackables} objects that were " +
                $"not bound to checkpointed values, likely due to changes in the " +
                $"program. Showing {num_variables_to_show} of " +
                $"{num_unused_trackables} unmatched objects: {examples}");
        }
        return this;
    }

    public override LoadStatus assert_nontrivial_match()
    {
        throw new NotImplementedException();
    }

    public override LoadStatus expect_partial()
    {
        throw new NotImplementedException();
    }

    public override void initialize_or_restore(Session? session = null)
    {
        throw new NotImplementedException();
    }

    public override LoadStatus run_restore_ops(Session? session = null)
    {
        throw new NotImplementedException();
    }
}

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint


// tf python has code `with ops.device(restore_device):` here. // tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky. tf.device(restore_device); // may be risky.
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());


Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0; int idx = 0;


+ 331
- 0
src/TensorFlowNET.Core/Checkpoint/restore.cs View File

@@ -0,0 +1,331 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using static Tensorflow.Binding;

namespace Tensorflow.Checkpoint;

public class CheckpointPosition
{
// Coordinator shared by every position created during this restore pass.
private CheckpointRestoreCoordinator _checkpoint;
// Index of the node this position points at inside the object-graph proto.
private int _proto_id;
// When true, restoration of this position is skipped; never set true in visible code.
private bool _skip_restore;
public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id)
{
    _checkpoint = checkpoint;
    _proto_id = proto_id;
    _skip_restore = false;
}

// The object currently bound to this proto node (requires a prior bind_project call).
public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];
public CheckpointRestoreCoordinator Checkpoint => _checkpoint;
// The proto node this position addresses.
public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id];

/// <summary>
/// Binds this proto node to <paramref name="trackable"/> and, when the binding is
/// new, restores it together with all reachable descendants.
/// </summary>
/// <param name="trackable">The object that should receive the checkpointed values.</param>
public void restore(Trackable trackable)
{
    using (ops.init_scope())
    {
        // bind_project returns false when the node was already bound; nothing to do then.
        if (bind_project(trackable))
        {
            var restore_ops = _restore_descendants();
            if(restore_ops is not null && restore_ops.Count > 0)
            {
                _checkpoint.new_restore_ops(restore_ops);
            }
        }
    }
}

/// <summary>
/// Records a checkpoint&lt;-&gt;object correspondence for this proto node.
/// </summary>
/// <param name="trackable">Object to associate with this node.</param>
/// <returns>True when the binding is new; false when the node was already bound.</returns>
public bool bind_project(Trackable trackable)
{
    _checkpoint.AllTrackables.Add(trackable);
    _checkpoint.MatchedProtoIds.Add(_proto_id);
    // An earlier binding wins; the duplicate is ignored (warning intentionally skipped).
    if (_checkpoint.ObjectByProtoId.ContainsKey(_proto_id))
    {
        return false;
    }
    _checkpoint.ObjectByProtoId[_proto_id] = trackable;
    return true;
}

/// <summary>
/// Gathers existing restore ops and named saveables for this node's serialized attributes.
/// </summary>
/// <returns>
/// (existing restore ops, saveables keyed by checkpoint key, child positions needing a
/// full restore — always empty here, registered savers — always null here).
/// </returns>
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
{
    // skip the registered_saver

    // Nothing serialized for this node: nothing to restore.
    if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0)
    {
        return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
            new List<CheckpointPosition>(), null);
    }

    var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable);

    List<Operation> existing_restore_ops;
    List<CheckpointPosition> positions = new();
    Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables;
    // A single special factory means the trackable serializes itself to tensors;
    // otherwise fall back to one saveable per attribute name.
    if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
    {
        (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories);
    }
    else if(saveable_factories.Count > 0)
    {
        (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories);
    }
    else
    {
        throw new NotImplementedException();
    }
    return (existing_restore_ops, named_saveables, positions, null);
}

/// <summary>
/// Builds a position for the child node <paramref name="node_id"/>, sharing this
/// position's restore coordinator.
/// </summary>
public CheckpointPosition create_child_position(int node_id)
    => new CheckpointPosition(_checkpoint, node_id);

/// <summary>
/// Creates a restore position for an optimizer slot variable.
/// NOTE(review): currently a stub — always returns (null, null) until slot-variable
/// restoration is implemented; callers must tolerate nulls.
/// </summary>
public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable,
    int slot_variable_id, string slot_name)
{
    //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id);

    // TODO(Rinne): implement it.
    return (null, null);
}

/// <summary>
/// Creates a saveable using the `_serialize_to_tensor` factory registered for the trackable.
/// </summary>
/// <param name="saveable_factories">Factories keyed by attribute name.</param>
/// <returns>No existing restore ops, plus the single saveable keyed by its name.</returns>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
    IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
    var suffix = SaveableCompat.get_saveable_name(this.Trackable) ?? "";
    var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix;

    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("The restore under graph mode has not been implemented. " +
            "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }

    // Caching of saveables is intentionally not implemented yet.
    var factory = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME];
    var named = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>
    {
        [saveable_name] = factory(saveable_name)
    };
    return (new List<Operation>(), named);
}

/// <summary>
/// Creates one saveable per serialized attribute of this node, reusing restore ops
/// already built for the same checkpoint key where possible.
/// </summary>
/// <param name="saveable_factories">Factories keyed by attribute name.</param>
/// <returns>(restore ops reused from a previous build, saveables keyed by checkpoint key)</returns>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
    IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
    // TODO(Rinne): implement it.
    if(ObjectProto.Attributes is null)
    {
        return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>());
    }

    List<Operation> existing_restore_ops = new();
    HashSet<string> created_compat_names = new();
    Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new();
    foreach (var serialized_tensor in ObjectProto.Attributes)
    {
        Operation existing_op;
        // Restore ops are only cached by name in graph mode.
        if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey))
        {
            existing_op = null;
        }
        else
        {
            existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey];
        }

        if(existing_op is not null)
        {
            existing_restore_ops.Add(existing_op);
            continue;
        }

        // A compatibility factory created earlier already covers this attribute.
        if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x)))
        {
            continue;
        }

        // TODO(Rinne): deal with cache.

        var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names);
        if(saveable is null)
        {
            // No factory matched: record the attribute so mismatches can be reported later.
            _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name);
            continue;
        }
        named_saveables[serialized_tensor.CheckpointKey] = saveable;
    }
    return (existing_restore_ops, named_saveables);
}

/// <summary>
/// Resolves the saveable factory for a serialized tensor, falling back to a
/// prefix-compatible factory for checkpoints written by newer TensorFlow versions.
/// </summary>
/// <param name="saveable_factories">Factories keyed by attribute name.</param>
/// <param name="serialized_tensor">The attribute being restored.</param>
/// <param name="created_compat_names">Accumulates factory names used via the prefix path.</param>
/// <returns>The created saveable, or null when no factory matches.</returns>
private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories,
    TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names)
{
    var expected_factory_name = serialized_tensor.Name;
    var factory_input_name = serialized_tensor.CheckpointKey;

    if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory))
    {
        // Forward-compatibility path: accept a factory whose name is a prefix of the
        // expected one (the checkpoint may come from a newer writer).
        foreach (var item in saveable_factories)
        {
            var factory_name = item.Key;
            var factory = item.Value;
            if (expected_factory_name.StartsWith(factory_name))
            {
                if (matched_factory is not null)
                {
                    throw new ValueError($"Forward compatibility load error: Unable to load " +
                        "checkpoint saved in future version of TensorFlow. " +
                        "Please update your version of TensorFlow to the " +
                        "version in which the checkpoint was saved.");
                }
                // Bug fix: these three statements used to sit OUTSIDE the prefix check,
                // so they ran for every factory — the last factory always won and the
                // ambiguity guard above could never fire.
                matched_factory = factory;
                factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name;
                created_compat_names.Add(factory_name);
            }
        }
    }
    // Bug fix: mirror the Python behavior of returning null when nothing matched
    // (the caller checks for null) instead of invoking a null delegate.
    if (matched_factory is null)
    {
        return null;
    }
    return matched_factory(factory_input_name);
}

/// <summary>
/// Returns the prefix of <paramref name="checkpoint_key"/> up to and including
/// the object-attributes marker (<c>OBJECT_ATTRIBUTES_NAME + "/"</c>).
/// </summary>
/// <param name="checkpoint_key">A full checkpoint key containing the attributes marker.</param>
/// <returns>The key prefix ending right after the marker.</returns>
/// <exception cref="ValueError">
/// Thrown when the key does not contain the marker. Previously a missing marker
/// made IndexOf return -1 and silently produced a wrong substring.
/// </exception>
private string _extract_saveable_name(string checkpoint_key)
{
    var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/";
    int index = checkpoint_key.IndexOf(search_key, StringComparison.Ordinal);
    if (index < 0)
    {
        throw new ValueError($"Could not find object attributes marker '{search_key}' in checkpoint key '{checkpoint_key}'.");
    }
    return checkpoint_key.Substring(0, index + search_key.Length);
}

/// <summary>
/// Restore the bound Trackable and dependencies (may be deferred).
/// </summary>
/// <summary>
/// Restore the bound Trackable and dependencies (may be deferred).
/// Performs a breadth-first walk over this position and every reachable child
/// or slot-variable position, accumulating restore ops and saveables.
/// </summary>
/// <returns>All restore operations produced by the traversal.</returns>
private List<Operation> _restore_descendants()
{
    var pending = new Queue<(CheckpointPosition, Trackable)>();
    pending.Enqueue((this, this.Trackable));

    var restore_ops = new List<Operation>();
    var tensor_saveables = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
    var positions = new List<CheckpointPosition>();

    CheckpointPosition position = null;
    while (pending.Count > 0)
    {
        // Only the position is needed here; the paired trackable is reachable through it.
        position = pending.Dequeue().Item1;
        var (ops, saveables, more_positions, _) = position._single_restore();
        restore_ops.AddRange(ops);
        foreach (var pair in saveables)
        {
            // Add (rather than indexer) so a duplicate checkpoint key is reported loudly.
            tensor_saveables.Add(pair.Key, pair.Value);
        }
        positions.AddRange(more_positions);
        _queue_children_for_restoration(position, pending);
        _queue_slot_variables(position, pending);
    }
    // `position` is the last visited position; the queue always starts non-empty.
    restore_ops.AddRange(position.Checkpoint.restore_saveables(tensor_saveables, positions, null));
    return restore_ops;
}

/// <summary>
/// Enqueues the restorable children of <paramref name="checkpoint_position"/>;
/// children whose local dependency is not yet tracked are deferred instead.
/// </summary>
/// <param name="checkpoint_position">The position whose proto children are examined.</param>
/// <param name="visit_queue">Traversal queue consumed by the restore walk.</param>
private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
    var parent = checkpoint_position.Trackable;
    foreach (var child in checkpoint_position.ObjectProto.Children)
    {
        var child_position = checkpoint_position.create_child_position(child.NodeId);
        var dependency = parent._lookup_dependency(child.LocalName);
        if (dependency is null)
        {
            // Not tracked yet: defer restoration, but only when the child proto
            // actually carries state worth restoring later.
            var proto = child_position.ObjectProto;
            bool has_state = proto.Children.Any() || proto.Attributes.Any() || proto.SlotVariables.Any();
            if (has_state)
            {
                parent.DeferredDependencies.SetDefault(child.LocalName, new List<CheckpointPosition>()).Add(child_position);
            }
        }
        else if (child_position.bind_project(dependency))
        {
            visit_queue.Enqueue((child_position, dependency));
        }
    }
}

/// <summary>
/// Queues deferred and pending slot-variable restorations for a position.
/// </summary>
/// <param name="checkpoint_position">The position whose slot variables are queued.</param>
/// <param name="visit_queue">Traversal queue consumed by the restore walk.</param>
private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
    var trackable = checkpoint_position.Trackable;
    var checkpoint = checkpoint_position.Checkpoint;
    if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions))
    {
        checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id);
        foreach (var deferred_slot_restoration in positions)
        {
            var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position(
                trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId,
                deferred_slot_restoration.SlotName
            );
            if(slot_variable_position is not null)
            {
                visit_queue.Enqueue((slot_variable_position, slot_variable));
            }
        }
    }
    if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations))
    {
        checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id);
        foreach (var slot_restoration in restorations)
        {
            // CONSISTENCY FIX: the original consulted `this.Checkpoint` here and below,
            // while the rest of this method operates on `checkpoint_position.Checkpoint`
            // (the local `checkpoint`). Use the passed position's checkpoint throughout.
            if(checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object))
            {
                throw new NotImplementedException();
                // TODO(Rinne): implement it.
            }
            else
            {
                // Optimizer not restored yet: remember the slot so it can be restored
                // once the optimizer object appears.
                Debug.Assert(trackable is BaseResourceVariable);
                checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List<DeferredSlotVariableRestoration>())
                    .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName));
            }
        }
    }
}

/// <summary>
/// Restores this position's trackable when the checkpoint being restored is
/// newer than the trackable's last applied update; otherwise returns empty results.
/// </summary>
/// <returns>Restore ops, saveables by checkpoint key, child positions, and registered savers (may be null).</returns>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
{
    var trackable = this.Trackable;
    trackable._maybe_initialize_trackable();
    if (_checkpoint.RestoreUid <= trackable.UpdateUid)
    {
        // Already restored by this (or a newer) checkpoint: nothing to do.
        return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
            new List<CheckpointPosition>(), null);
    }
    var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables();
    trackable.UpdateUid = _checkpoint.RestoreUid;
    return (restore_ops, tensor_saveables, positions, registered_savers);
}
}

/// <summary>
/// Records a slot-variable restoration that must wait until its owning optimizer
/// object has been restored (see <c>_queue_slot_variables</c> usage in this file).
/// </summary>
/// <param name="OriginalVariable">The variable the slot belongs to.</param>
/// <param name="SlotVariableId">Node id of the slot variable in the object-graph proto.</param>
/// <param name="SlotName">Name of the slot, as stored in the checkpoint proto.</param>
public record class DeferredSlotVariableRestoration(
    BaseResourceVariable OriginalVariable,
    int SlotVariableId,
    string SlotName
);

+ 5
- 1
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -10,7 +10,7 @@ using static Tensorflow.Binding;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
internal class execute
internal static class execute
{ {
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
{ {
@@ -27,5 +27,9 @@ namespace Tensorflow.Eager


return tensors; return tensors;
} }
public static bool must_record_gradient()
{
return false;
}
} }
} }

+ 13
- 2
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -13,8 +13,8 @@ namespace Tensorflow.Functions
/// </summary> /// </summary>
public class ConcreteFunction: Trackable public class ConcreteFunction: Trackable
{ {
FuncGraph func_graph;
ForwardBackwardCall forward_backward;
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs; public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures; public Tensor[] CapturedInputs => func_graph.external_captures;


@@ -23,6 +23,8 @@ namespace Tensorflow.Functions
public Tensor[] Outputs; public Tensor[] Outputs;
public Type ReturnType; public Type ReturnType;
public TensorSpec[] OutputStructure; public TensorSpec[] OutputStructure;
public IEnumerable<string> ArgKeywords { get; set; }
public long NumPositionArgs { get; set; }


public ConcreteFunction(string name) public ConcreteFunction(string name)
{ {
@@ -163,6 +165,15 @@ namespace Tensorflow.Functions
return flat_outputs; return flat_outputs;
} }


public void AddTograph(Graph? g = null)
{
if(!tf.Context.executing_eagerly() && g is null)
{
g = ops.get_default_graph();
}
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{ {
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); var functions = new FirstOrderTapeGradientFunctions(func_graph, false);


+ 12
- 0
src/TensorFlowNET.Core/IO/gfile.cs View File

@@ -16,8 +16,10 @@


using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using static Tensorflow.Binding;


namespace Tensorflow.IO namespace Tensorflow.IO
{ {
@@ -63,5 +65,15 @@ namespace Tensorflow.IO
dirs.AddRange(Directory.GetFiles(dir)); dirs.AddRange(Directory.GetFiles(dir));
return dirs.ToArray(); return dirs.ToArray();
} }

public string join(params string[] paths)
{
Debug.Assert(paths.Length >= 1);
if (paths[0].Substring(1).Contains("://"))
{
throw new NotImplementedException("The combination of urls has not been implemented.");
}
return Path.Combine(paths);
}
} }
} }

+ 10
- 1
src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs View File

@@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common


public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{ {
var axis = serializer.Deserialize(reader, typeof(long[]));
int[]? axis;
if(reader.ValueType == typeof(long))
{
axis = new int[1];
axis[0] = (int)serializer.Deserialize(reader, typeof(int));
}
else
{
axis = serializer.Deserialize(reader, typeof(int[])) as int[];
}
if (axis is null) if (axis is null)
{ {
throw new ValueError("Cannot deserialize 'null' to `Axis`."); throw new ValueError("Cannot deserialize 'null' to `Axis`.");


+ 36
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs View File

@@ -0,0 +1,36 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;

namespace Tensorflow.Keras.Common
{
public class CustomizedDTypeJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(TF_DataType);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
token.WriteTo(writer);
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
if (reader.ValueType == typeof(string))
{
var str = (string)serializer.Deserialize(reader, typeof(string));
return dtypes.tf_dtype_from_name(str);
}
else
{
return (TF_DataType)serializer.Deserialize(reader, typeof(int));
}
}
}
}

+ 32
- 5
src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs View File

@@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common
{ {
throw new ValueError("Cannot deserialize 'null' to `Shape`."); throw new ValueError("Cannot deserialize 'null' to `Shape`.");
} }
if(values.Length != 3)
if(values.Length == 1)
{
var array = values[0] as JArray;
if(array is null)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
values = array.ToObject<object[]>();
}
if (values.Length < 3)
{ {
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
} }
@@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common
{ {
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
} }
if (values[1] is not int)
int nodeIndex;
int tensorIndex;
if (values[1] is long)
{
nodeIndex = (int)(long)values[1];
}
else if (values[1] is int)
{
nodeIndex = (int)values[1];
}
else
{ {
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
} }
if (values[2] is not int)
if (values[2] is long)
{
tensorIndex = (int)(long)values[2];
}
else if (values[1] is int)
{
tensorIndex = (int)values[2];
}
else
{ {
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
} }
return new NodeConfig() return new NodeConfig()
{ {
Name = values[0] as string, Name = values[0] as string,
NodeIndex = (int)values[1],
TensorIndex = (int)values[2]
NodeIndex = nodeIndex,
TensorIndex = tensorIndex
}; };
} }
} }


+ 21
- 3
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -51,10 +51,28 @@ namespace Tensorflow.Keras.Common


public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{ {
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
if(dims is null)
long?[] dims;
try
{ {
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
catch (JsonSerializationException ex)
{
if (reader.Value.Equals("class_name"))
{
reader.Read();
reader.Read();
reader.Read();
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
}
else
{
throw ex;
}
}
if (dims is null)
{
return null;
} }
long[] convertedDims = new long[dims.Length]; long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++) for(int i = 0; i < dims.Length; i++)


+ 1
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -19,6 +19,7 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableVariables { get; } List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; } List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; } List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; }
Shape OutputShape { get; } Shape OutputShape { get; }
Shape BatchInputShape { get; } Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; } TensorShapeConfig BuildInputShape { get; }


+ 3
- 0
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -1,8 +1,11 @@
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {


+ 1
- 0
src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training.Saving.SavedModel;


namespace Tensorflow.ModelSaving namespace Tensorflow.ModelSaving
{ {


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -71,6 +71,7 @@ namespace Tensorflow


public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
public List<IVariableV1> Weights => throw new NotImplementedException();
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();


public Shape OutputShape => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException();


+ 42
- 1
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations
/// ///
/// Callers must ensure all the named tensors are indeed stored in the checkpoint. /// Callers must ensure all the named tensors are indeed stored in the checkpoint.
/// </remarks> /// </remarks>
public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2")
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2")
{ {
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
Dictionary<string, object> attrs = new();
attrs["dtypes"] = dtypes;
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(
"RestoreV2", name, prefix, tensor_names, shape_and_slices
)
{ attrs = attrs });
return result;
}
catch (Exception)
{
try
{
return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx);
}
catch (Exception)
{

}
}
}
var dict = new Dictionary<string, object>(); var dict = new Dictionary<string, object>();
dict["prefix"] = prefix; dict["prefix"] = prefix;
dict["tensor_names"] = tensor_names; dict["tensor_names"] = tensor_names;
@@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations
return (tensors); return (tensors);
} }


public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx)
{
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);
object[] attrs = new object[] { "dtypes", dtypes };
Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor };
var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name);

if (execute.must_record_gradient())
{
// TODO(Rinne); record the gradient
}
return result;
}

/// <summary> /// <summary>
/// Reverses specific dimensions of a tensor. /// Reverses specific dimensions of a tensor.
/// </summary> /// </summary>


+ 1
- 0
src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -62,6 +62,7 @@ namespace Tensorflow


public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{ {
// Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`.
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });


return _op.outputs; return _op.outputs;


+ 1
- 1
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -17,8 +17,8 @@
using System; using System;
using System.Linq; using System.Linq;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.ModelSaving;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Variables; using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types; using static Tensorflow.CppShapeInferenceResult.Types;




+ 5
- 1
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -1,9 +1,13 @@
namespace Tensorflow
using Newtonsoft.Json;
using Tensorflow.Keras.Common;

namespace Tensorflow
{ {
/// <summary> /// <summary>
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
/// The enum values here are identical to corresponding values in types.proto. /// The enum values here are identical to corresponding values in types.proto.
/// </summary> /// </summary>
[JsonConverter(typeof(CustomizedDTypeJsonConverter))]
public enum TF_DataType public enum TF_DataType
{ {
DtInvalid = 0, DtInvalid = 0,


+ 3
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -159,7 +159,10 @@ namespace Tensorflow
"uint32" => TF_DataType.TF_UINT32, "uint32" => TF_DataType.TF_UINT32,
"int64" => TF_DataType.TF_INT64, "int64" => TF_DataType.TF_INT64,
"uint64" => TF_DataType.TF_UINT64, "uint64" => TF_DataType.TF_UINT64,
"float16" => TF_DataType.TF_BFLOAT16,
"float32" => TF_DataType.TF_FLOAT,
"single" => TF_DataType.TF_FLOAT, "single" => TF_DataType.TF_FLOAT,
"float64" => TF_DataType.TF_DOUBLE,
"double" => TF_DataType.TF_DOUBLE, "double" => TF_DataType.TF_DOUBLE,
"complex" => TF_DataType.TF_COMPLEX128, "complex" => TF_DataType.TF_COMPLEX128,
"string" => TF_DataType.TF_STRING, "string" => TF_DataType.TF_STRING,


+ 18
- 0
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -39,6 +39,24 @@ namespace Tensorflow
_op = value; _op = value;
} }
} }
public BaseResourceVariable variable
{
get
{
if (_op.TryGet<BaseResourceVariable>(out var v))
{
return v;
}
else
{
throw new TypeError("The _op is not a variable.");
}
}
set
{
_op = value;
}
}
public SaveSpec[] specs; public SaveSpec[] specs;
public string name; public string name;
public string device; public string device;


+ 23
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public record class LoadOptions
{
public bool allow_partial_checkpoint;
public string experimental_io_device;
public bool experimental_skip_checkpoint;
public VariablePolicy experimental_variable_policy;

public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null,
bool experimental_skip_checkpoint = false, string experimental_variable_policy = null)
{
this.allow_partial_checkpoint = allow_partial_checkpoint;
this.experimental_io_device = experimental_io_device;
this.experimental_skip_checkpoint = experimental_skip_checkpoint;
this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy);
}
}
}

+ 8
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.Train;
using System;
using Tensorflow.Train;


namespace Tensorflow; namespace Tensorflow;


@@ -14,4 +15,10 @@ public class RevivedTypes
// TODO: complete the implementation. // TODO: complete the implementation.
return null; return null;
} }

public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto)
{
// TODO: complete the implementation.
return null;
}
} }

src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs → src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs View File

@@ -2,7 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.ModelSaving
namespace Tensorflow
{ {
/// <summary> /// <summary>
/// Options for saving to SavedModel. /// Options for saving to SavedModel.
@@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving


public bool save_variable_devices() public bool save_variable_devices()
{ {
return this != VariablePolicy.None;
return this != None;
} }


/// <summary> /// <summary>
@@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving
/// <returns></returns> /// <returns></returns>
public static VariablePolicy from_obj(object obj) public static VariablePolicy from_obj(object obj)
{ {
if (obj is null) return VariablePolicy.None;
if (obj is null) return None;
if (obj is VariablePolicy) return (VariablePolicy)obj; if (obj is VariablePolicy) return (VariablePolicy)obj;
var key = obj.ToString().ToLower(); var key = obj.ToString().ToLower();
return key switch return key switch
{ {
null => VariablePolicy.None,
"save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES,
"expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
null => None,
"save_variable_devices" => SAVE_VARIABLE_DEVICES,
"expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES,
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
}; };
} }

+ 0
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -5,7 +5,6 @@ using System.Linq;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training; using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections; using pbc = global::Google.Protobuf.Collections;


+ 22
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Functions;

namespace Tensorflow.Training.Saving.SavedModel
{
/// <summary>
/// A class wraps a concrete function to handle different distributed contexts.
/// </summary>
internal class WrapperFunction: ConcreteFunction
{
public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
{
this.forward_backward = concrete_function.forward_backward;
this.Outputs = concrete_function.Outputs;
this.ReturnType = concrete_function.ReturnType;
this.OutputStructure = concrete_function.OutputStructure;
this.ArgKeywords = concrete_function.ArgKeywords;
}
}
}

+ 36
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Util;

namespace Tensorflow.Training.Saving.SavedModel
{
public static class function_deserialization
{
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
IDictionary<string, ConcreteFunction> concrete_functions)
{
var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName];
concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;

var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
concrete_function.AddTograph();
return concrete_function;
}

private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
{
// TODO(Rinne); revise the implementation.
return new FunctionSpec()
{
Fullargspec = function_spec_proto.Fullargspec,
IsMethod = function_spec_proto.IsMethod,
InputSignature = function_spec_proto.InputSignature,
JitCompile = function_spec_proto.JitCompile
};
}
}
}

+ 641
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -0,0 +1,641 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using Tensorflow.Checkpoint;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using System.Runtime.CompilerServices;
using Tensorflow.Variables;
using Tensorflow.Functions;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow
{
/// <summary>
/// Helper class to load an object-based SavedModel.
/// </summary>
public partial class Loader
{
private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def;
private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes;
private SavedObjectGraph _proto;
private string _export_dir;
private CheckpointOptions _checkpoint_options;
private LoadOptions _save_options;
private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters;
private Dictionary<string, int>? _node_path_to_id;
private List<int>? _filtered_nodes;
private List<int> _ordered_node_ids;
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes;
private List<Trackable> _nodes;
private Dictionary<int, Action<object, object, object>> _node_setters;
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir,
CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters)
{
var meta_graph = saved_model_proto.MetaGraphs[0];
_asset_file_def = meta_graph.AssetFileDef;
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
_proto = object_graph_proto;
_export_dir = export_dir;
// TODO: `this._concrete_functions` and `this._restored_concrete_functions`
_checkpoint_options = ckpt_options;
_save_options = save_options;

// TODO: `this._pretty_printer`

_node_filters = filters;
_node_path_to_id = _convert_node_paths_to_ints();
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
foreach(var filter in filters)
{
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
}

_filtered_nodes = _retrieve_all_filtered_nodes();

_ordered_node_ids = _generate_ordered_node_ids();

_load_all();


if (!save_options.experimental_skip_checkpoint)
{
_restore_checkpoint();
}
foreach(var node in _nodes)
{
// skip the process of `CapturableResource`.
}
}

/// <summary>
/// Maps all string node paths in node_filters to the int node ids.
/// </summary>
/// <returns></returns>
private Dictionary<string, int>? _convert_node_paths_to_ints()
{
if( _node_filters is null)
{
return null;
}
Dictionary<string, int> path_to_int = new();
foreach(var node_id in _node_filters.Keys)
{
int int_node_id;
var node_path = node_id.Split('.');
if (node_path[0] != "root")
{
throw new ValueError($"When passing string identifiers to node_filters, the first name" +
$" must be root. Received {node_path[0]}.");
}
int_node_id = 0;
for(int i = 0; i < node_path.Length - 1; i++)
{
var name = node_path[i + 1];
int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1)));
}
path_to_int[node_id] = int_node_id;
}
return path_to_int;
}

private int _find_node_child(int node_id, string child_name, string path)
{
foreach(var refer in _proto.Nodes[node_id].Children)
{
if(refer.LocalName == child_name)
{
return refer.NodeId;
}
}
throw new ValueError($"Unable to find node {path}.");
}

private List<int>? _retrieve_all_filtered_nodes()
{
if(_node_filters is null)
{
return null;
}

HashSet<int> all_filtered_nodes = new();
Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys);

while(nodes_to_visit.Count > 0)
{
var node_path = nodes_to_visit.Dequeue();
var node_id = _node_path_to_id[node_path];
if (all_filtered_nodes.Contains(node_id))
{
continue;
}
all_filtered_nodes.Add(node_id);
Trackable node = null;
Action<object, object, object> setter = null;
if(_loaded_nodes.TryGetValue(node_id, out var res))
{
(node, setter) = res;
}
if(node is not null)
{
node._maybe_initialize_trackable();
}

foreach(var refer in _proto.Nodes[node_id].Children)
{
Trackable children_object = null;
if(_loaded_nodes.TryGetValue(refer.NodeId, out var result))
{
children_object = result.Item1;
}
// See if node already tracks the child reference, in which case add the child to the loaded_nodes dict.
if(children_object is null && node is not null)
{
children_object = node._lookup_dependency(refer.LocalName);
if(children_object is TrackableDataStructure)
{
// TODO: set setter as lambda.

_loaded_nodes[refer.NodeId] = (children_object, setter);
}
}
string child_path = $"{node_path}.{refer.LocalName}";
_node_path_to_id[child_path] = refer.NodeId;
nodes_to_visit.Enqueue(child_path);
}
}

if (all_filtered_nodes.Contains(0))
{
return null;
}
return all_filtered_nodes.ToList();
}

/// <summary>
/// Orders the node ids so that dependencies appear first.
/// </summary>
/// <returns>Node ids sorted so that every node comes after its dependencies.</returns>
/// <exception cref="ValueError">
/// Thrown when a filtered load is missing a required dependency, or when the
/// dependency graph contains a cycle.
/// </exception>
private List<int> _generate_ordered_node_ids()
{
    List<int> unordered_ids;
    if (_filtered_nodes is null)
    {
        unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList();
    }
    else
    {
        unordered_ids = new List<int>(_filtered_nodes);
    }

    Dictionary<int, List<int>> dependency_map = new();
    foreach (var node_id in unordered_ids)
    {
        var deps = dependency_map.SetDefault(node_id, new List<int>());
        // Nodes that were supplied externally are already created and need no
        // ordering constraints of their own.
        if (_loaded_nodes.ContainsKey(node_id))
        {
            continue;
        }
        var proto = _proto.Nodes[node_id];
        foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
        {
            deps.Add(dep);
            if (_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
            {
                // TODO: add info with `_pretty_printer`.
                throw new ValueError($"Unable to partially load SavedModel since the specified filter " +
                    $"does not include all required objects for loading (e.g. " +
                    $"variables used in functions or deserialization dependencies). " +
                    $"Please include this path in the filter: {dep}");
            }
        }
        int? prev_slot = null;
        foreach (var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            // The optimizer and original variable must be created before the slot
            // variable, since the slot variable is generated using the Optimizer's
            // add_slot API.
            // Use SetDefault (defaultdict semantics, matching the python original):
            // the slot variable node may not have been visited yet, in which case
            // direct indexing would throw KeyNotFoundException.
            var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>());
            slot_deps.Add(node_id);
            slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);

            if (prev_slot is not null)
            {
                // Chain slot variables so their creation order is deterministic.
                slot_deps.Add(prev_slot.Value);
            }
            prev_slot = slot_variable_node_id;
        }
    }
    try
    {
        return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
    }
    catch (TrackableUtils.CyclicDependencyError)
    {
        // Note: the message segments are joined with explicit spaces (the original
        // concatenation produced "dependenciesin the SavedModel").
        throw new ValueError("Encountered a cycle in the deserialization dependencies " +
            "in the SavedModel. This is extremely unexpected, please " +
            "file a bug and make sure you are not manually modifying the SavedModel.");
    }
}

/// <summary>
/// Returns a dictionary of all dependencies of an object.
/// </summary>
/// <param name="proto">The saved object whose dependencies are collected.</param>
/// <returns>Mapping from reference name (or bound-input index) to node id.</returns>
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto)
{
    Dictionary<Maybe<string, int>, int> deps = new();
    foreach (var dependency in proto.Dependencies)
    {
        deps[dependency.LocalName] = dependency.NodeId;
    }
    switch (proto.KindCase)
    {
        case SavedObject.KindOneofCase.Function:
            // Captured inputs of every concrete function are dependencies.
            foreach (var fn_name in proto.Function.ConcreteFunctions)
            {
                foreach (var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
                {
                    deps[bound_input] = bound_input;
                }
            }
            break;
        case SavedObject.KindOneofCase.BareConcreteFunction:
            var concrete_name = proto.BareConcreteFunction.ConcreteFunctionName;
            foreach (var bound_input in _proto.ConcreteFunctions[concrete_name].BoundInputs)
            {
                deps[bound_input] = bound_input;
            }
            break;
        case SavedObject.KindOneofCase.Resource:
            // Resources depend on their `_create_resource` child, if present.
            foreach (var child in proto.Children)
            {
                if (child.LocalName == "_create_resource")
                {
                    deps["_create_resource"] = child.NodeId;
                }
            }
            break;
    }
    return deps;
}

/// <summary>
/// Loads all nodes and functions from the SavedModel and their edges.
/// </summary>
private void _load_all()
{
    // Order matters: nodes must exist before edges can be wired between them,
    // and the save/restore functions reference the finished nodes.
    _load_nodes();
    _load_edges();

    _setup_remaining_functions();
    _load_checkpoint_save_and_restore_functions();
}

/// <summary>
/// Restores the checkpoint-related save/restore functions to all nodes.
/// </summary>
private void _load_checkpoint_save_and_restore_functions()
{
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        var trackable = get(node_id);
        if (trackable is null)
        {
            // skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
            continue;
        }
        var saveable_objects = proto.SaveableObjects;
        if (saveable_objects.Keys.Count == 1 && saveable_objects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
        {
            // Restore Trackable serialize- and restore-from-tensor functions.
            Debug.Assert(saveable_objects.Count == 1);
            var serialized = saveable_objects.Values.First();
            var save_fn_id = serialized.SaveFunction;
            var restore_fn_id = serialized.RestoreFunction;

            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
        }
        else
        {
            // Restore legacy SaveableObject functions.
            Dictionary<string, (Trackable, Trackable)> saveable_fn_by_name = new();
            foreach (var pair in saveable_objects)
            {
                saveable_fn_by_name[pair.Key] = (get(pair.Value.SaveFunction), get(pair.Value.RestoreFunction));
            }
            trackable.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
        }
    }
}

/// <summary>
/// Load all saved objects.
/// </summary>
private void _load_nodes()
{
    // `nodes` maps from node ids to recreated objects
    // `node_setters` maps from node ids to setter functions
    // (same signature as setattr) for setting children.
    var (nodes, node_setters) = _initialize_loaded_nodes();

    // Maps a slot-variable node id to its (optimizer node id, slot proto) pair,
    // so slot variables can be created through the optimizer below.
    Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)>
        slot_variable_node_ids = new();

    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        foreach (var slot_variable_proto in proto.SlotVariables)
        {
            slot_variable_node_ids[slot_variable_proto.SlotVariableNodeId] = (node_id, slot_variable_proto);
        }
    }

    // Re-create everything.
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        if (nodes.ContainsKey(node_id))
        {
            // Already created (supplied externally through `_loaded_nodes`).
            continue;
        }
        else if (slot_variable_node_ids.TryGetValue(node_id, out var slot_info))
        {
            // Use the public Optimizer interface when creating slot variables.
            var (optimizer_node_id, slot_variable_proto) = slot_info;
            var optimizer_object = nodes[optimizer_node_id];
            var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];

            // TODO: implement it.
            throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
                " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        }
        else
        {
            // skip the function and concrete function.
            if (proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
            {
                nodes[node_id] = null;
                node_setters[node_id] = null;
                continue;
            }
            var (node, setter) = _recreate(proto, node_id, nodes);
            nodes[node_id] = node;
            node_setters[node_id] = setter;
        }
    }

    if (!nodes.ContainsKey(0))
    {
        // Guarantee a root object even when node 0 was filtered out.
        nodes[0] = _recreate_base_user_object().Item1;
    }
    _nodes = new List<Trackable>();
    for (int i = 0; i < _proto.Nodes.Count; i++)
    {
        _nodes.Add(nodes[i]);
    }
    _node_setters = node_setters;
}

/// <summary>
/// Load state from checkpoint into the deserialized objects.
/// </summary>
private void _restore_checkpoint()
{
    var variables_path = SavedModelUtils.get_variables_path(_export_dir);
    var saver = new TrackableSaver(new ObjectGraphView(get(0)));
    // NOTE(review): the return value of `tf.device` is discarded, so it is unclear
    // whether the CPU device scope actually applies to the ops created below —
    // confirm against the intended `tf.device` API usage (python uses `with`).
    tf.device("CPU");
    saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
    LoadStatus load_status;
    if (_save_options.allow_partial_checkpoint)
    {
        // Partial loading: tolerate checkpoint values without a matching object.
        load_status = saver.restore(variables_path, _checkpoint_options).expect_partial();
        load_status.assert_nontrivial_match();
    }
    else
    {
        load_status = saver.restore(variables_path, _checkpoint_options);
        load_status.assert_existing_objects_matched();
    }
    var ckpt = (load_status as CheckpointLoadStatus).Checkpoint;

    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("The checkpoint restore has not supported graph mode. " +
            "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}

/// <summary>
/// Adds edges from objects to other objects and functions.
/// </summary>
private void _load_edges()
{
    foreach (var (node_id, object_proto) in _iter_all_nodes())
    {
        _add_object_graph_edges(object_proto, node_id);
    }

    // When the root was filtered, re-attach the filtered nodes to it by path.
    if (_filtered_nodes is null || !_filtered_nodes.Contains(0))
    {
        return;
    }
    var root = get(0);
    foreach (var node_path in _node_filters.Keys)
    {
        var loaded_node = _nodes[_node_path_to_id[node_path]];

        var segments = node_path.Split('.');
        var current_node = root;
        foreach (var name in segments.Skip(1).Take(segments.Length - 2))
        {
            // `hasattr` and `setattr` is used here
            throw new NotImplementedException();
        }
        // `hasattr` and `setattr` is used here
        throw new NotImplementedException();
    }
}

/// <summary>
/// Sets up the functions that were not restored as nodes. Currently a no-op.
/// </summary>
private void _setup_remaining_functions()
{
    // TODO: implement it with concrete functions.
}

/// <summary>
/// Returns the recreated node for the given node id.
/// </summary>
public Trackable get(int node_id) => _nodes[node_id];

/// <summary>
/// Returns the recreated node addressed by its object-graph path.
/// </summary>
public Trackable get(string node_id) => get(_node_path_to_id[node_id]);

/// <summary>
/// Adds edges from an object to its children.
/// </summary>
/// <param name="proto">The proto describing the object's child references.</param>
/// <param name="node_id">The id of the parent node.</param>
private void _add_object_graph_edges(SavedObject proto, int node_id)
{
    var obj = _nodes[node_id];
    if (obj is null)
    {
        // skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
        // (Null check hoisted out of the loop: it does not depend on the child.)
        return;
    }
    var setter = _node_setters[node_id];
    if (setter is null)
    {
        // An externally supplied node may have been registered without a setter;
        // previously this dereferenced null inside the loop.
        return;
    }

    foreach (var refer in proto.Children)
    {
        setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
        // skip the process of "__call__"
    }
}

/// <summary>
/// Splits `_loaded_nodes` into separate node and setter dictionaries.
/// </summary>
private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{
    Dictionary<int, Trackable> nodes = new();
    Dictionary<int, Action<object, object, object>> node_setters = new();
    foreach (var entry in _loaded_nodes)
    {
        nodes[entry.Key] = entry.Value.Item1;
        node_setters[entry.Key] = entry.Value.Item2;
    }
    return (nodes, node_setters);
}

/// <summary>
/// Lazily enumerates (node id, proto) pairs in dependency order.
/// </summary>
private IEnumerable<(int, SavedObject)> _iter_all_nodes()
{
    // Select is deferred, matching the original iterator's lazy semantics.
    return _ordered_node_ids.Select(node_id => (node_id, _proto.Nodes[node_id]));
}

/// <summary>
/// Re-creates the object for `proto`, resolving its dependencies from `nodes`.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{
    // skip the registered classes.

    // Map each dependency id to its already-created Trackable instance.
    var dependencies = _get_node_dependencies(proto)
        .ToDictionary(item => item.Key, item => nodes[item.Value]);

    return _recreate_default(proto, node_id, dependencies);
}

/// <summary>
/// Creates a Trackable object from a SavedObject protocol buffer.
/// </summary>
/// <param name="proto">The saved object to re-create.</param>
/// <param name="node_id">The id of the node in the object graph.</param>
/// <param name="dependencies">Already re-created dependencies of the node.</param>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies)
{
    return proto.KindCase switch
    {
        SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
        SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
        SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
        SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
        SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
        // Previously an unmatched kind (e.g. Asset, Constant) surfaced as an opaque
        // SwitchExpressionException; report the unsupported kind explicitly.
        _ => throw new NotImplementedException($"Loading SavedObject of kind {proto.KindCase} is not supported.")
    };
}

/// <summary>
/// Re-creates a user object, falling back to a plain base object when no
/// revived type is registered for the proto.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id)
{
    // skip the check of proto identifier because of lack of property.

    var revived = RevivedTypes.deserialize(proto);
    if (revived is not null)
    {
        return (revived.Item1, revived.Item2);
    }
    return _recreate_base_user_object(proto, node_id);
}

/// <summary>
/// Fallback creation: a plain trackable user object with the default attribute setter.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null)
{
    // `proto` and `node_id` are currently unused; kept for parity with the python API.
    return (new _UserObject(), setattr);
}

/// <summary>
/// Re-creates an uninitialized variable from a SavedVariable proto; its value is
/// filled in later by checkpoint restoration.
/// </summary>
/// <param name="proto">The saved variable description.</param>
private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto)
{
    string name = proto.Name;
    string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>";

    // TODO(Rinne): `validate_synchronization_aggregation_trainable`

    var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable(
        proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name);

    var saved_device = proto.Device;
    var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device);

    if (load_with_device)
    {
        // NOTE(review): the return value of `tf.device` is discarded, so the device
        // scope may not actually apply to the variable creation — confirm the
        // intended `tf.device` API usage (python uses a `with` block here).
        tf.device(saved_device);
    }
    // Both device paths previously built a byte-identical variable; the duplicated
    // construction branches were merged.
    return (new UninitializedVariable(
        shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
        dtype: (TF_DataType)proto.Dtype,
        name: name,
        trainable: trainable,
        aggregation: aggregation
        ), setattr);
}

/// <summary>
/// Re-creates a bare concrete function from its proto. Not implemented yet.
/// </summary>
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
    Dictionary<Maybe<string, int>, Trackable> dependencies)
{
    throw new NotImplementedException();
    //var fn = function_deserialization.setup_bare_concrete_function(proto, )
}

// TODO: remove this to a common class.
/// <summary>
/// Default attribute setter: assigns value `z` to the public property of `x`
/// named `y` (a string). Unknown property names are silently ignored.
/// </summary>
public static Action<object, object, object> setattr = (x, y, z) =>
{
    Debug.Assert(y is string);
    var target = x.GetType()
        .GetProperties()
        .FirstOrDefault(p => p.Name == (string)y);
    // TODO(Rinne): check if the property has been set successfully.
    //throw new ValueError($"Cannot find the property {y} of {x}.");
    target?.SetValue(x, z);
};

/// <summary>
/// Placeholder trackable used for saved user objects without a registered revived type.
/// </summary>
public class _UserObject: AutoTrackable
{

}
}
}

+ 122
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs View File

@@ -0,0 +1,122 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Checkpoint;
using Tensorflow.Operations;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class Loader
{
/// <summary>
/// Reads the SavedModel proto from `export_dir` (binary `saved_model.pb` or
/// text `saved_model.pbtxt`).
/// </summary>
/// <param name="export_dir">The directory containing the SavedModel.</param>
/// <returns>The parsed <see cref="SavedModel"/> protocol buffer.</returns>
/// <exception cref="IOException">Thrown when neither file exists in the directory.</exception>
public static SavedModel parse_saved_model(string export_dir)
{
    var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT);
    var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB);

    SavedModel saved_model = new SavedModel();
    if (File.Exists(path_to_pb))
    {
        // TODO: change to stream mode.
        // File.ReadAllBytes reads the whole file reliably; the previous single
        // FileStream.Read call could legally return fewer bytes than requested.
        saved_model.MergeFrom(File.ReadAllBytes(path_to_pb));
        return saved_model;
    }
    else if (File.Exists(path_to_pbtxt))
    {
        throw new NotImplementedException();
    }
    else
    {
        // DirectorySeparatorChar ('/' or '\\') is the separator between a directory
        // and a file name; PathSeparator (':' or ';') separates entries in
        // PATH-like lists and was wrong here.
        throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.DirectorySeparatorChar}" +
            $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}");
    }
}

// TODO: revise the type of `tags`
/// <summary>
/// Loads a SavedModel from `export_dir` and returns its root trackable object.
/// </summary>
public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null)
{
    return load_partial(export_dir, null, tags, options)["root"];
}

/// <summary>
/// Partially loads a SavedModel, optionally restricted by node <paramref name="filters"/>.
/// </summary>
/// <param name="export_dir">The SavedModel directory.</param>
/// <param name="filters">Mapping from node paths to externally created (node, setter) pairs; null loads everything.</param>
/// <param name="tags">MetaGraph tag selection; currently unsupported and must be null.</param>
/// <param name="options">Load options; a default instance is used when null.</param>
/// <returns>The filtered nodes keyed by path, or a single "root" entry when no filters are given.</returns>
public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null)
{
    if (options is null)
    {
        options = new LoadOptions();
    }
    if (tags is not null)
    {
        // Tag-based MetaGraph selection is not implemented yet.
        throw new NotImplementedException();
    }
    var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir);

    Trackable root = null;
    Loader loader = null;
    if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null)
    {
        // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)`
        var meta_graph_def = saved_model_proto.MetaGraphs[0];
        if (!BitConverter.IsLittleEndian)
        {
            // Tensor content is serialized little-endian; swap bytes on big-endian hosts.
            SavedModelUtils.swap_function_tensor_content(meta_graph_def);
        }

        var object_graph_proto = meta_graph_def.ObjectGraphDef;
        var ckpt_options = new CheckpointOptions(options.experimental_io_device);
        tf_with(ops.init_scope(), x =>
        {
            // `loader` and `root` are assigned inside the scope so the objects are
            // created under the init scope's graph/device context.
            loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
            root = loader.get(0);
            // skip the assignment of `graph_debug_info`.
        });
        // skip the assignment of `tensorflow_version`
        // skip the assignment of `tensorflow_git_version`
        // skip the process of `metrics`.
    }
    else
    {
        if(filters is not null && filters.Count > 0)
        {
            throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                + " version) cannot be loaded with node filters.");
        }
        tf_with(ops.init_scope(), x =>
        {
            // TF1-style SavedModels (no object graph) are not supported yet.
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        });
    }
    if(filters != null && filters.Count > 0)
    {
        return filters.Keys.ToDictionary(x => x, x => loader.get(x));
    }
    else
    {
        var res = new Dictionary<string, Trackable>();
        res["root"] = root;
        return res;
    }
}

/// <summary>
/// Parses the SavedModel and (eventually) its debug info. The debug info half is
/// not implemented yet and is always returned as null.
/// </summary>
public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir)
{
    // TODO: implement debug info.
    var saved_model = parse_saved_model(export_dir);
    return (saved_model, null);
}

}
}

+ 0
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -6,7 +6,6 @@ using System.Text;
using Google.Protobuf; using Google.Protobuf;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Exceptions; using Tensorflow.Exceptions;
using static Tensorflow.Binding; using static Tensorflow.Binding;


+ 0
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs View File

@@ -1,7 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.ModelSaving;


namespace Tensorflow.Training.Saving.SavedModel namespace Tensorflow.Training.Saving.SavedModel
{ {


+ 79
- 15
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -68,6 +68,34 @@ namespace Tensorflow
return saveables.ToArray(); return saveables.ToArray();
} }


public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables)
{
var saveables = new List<MySaveableObject>();
var seen_ops = new List<Tensor>();

foreach (var (name, op) in enumerate(names_to_saveables))
{
foreach (var converted_saveable_object in saveable_objects_for_op(op, name))
_add_saveable(saveables, seen_ops, converted_saveable_object);
}
return saveables.ToArray();
}

public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables)
{
var saveables = new List<MySaveableObject>();
var seen_ops = new List<BaseResourceVariable>();

foreach(var item in names_to_saveables.OrderBy(x => x.Key))
{
foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
{
_add_saveable(saveables, seen_ops, converted_saveable_object);
}
}
return saveables.ToArray();
}

private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject
{ {
if (seen_ops.Contains(saveable.op)) if (seen_ops.Contains(saveable.op))
@@ -77,6 +105,15 @@ namespace Tensorflow
seen_ops.Add(saveable.op); seen_ops.Add(saveable.op);
} }


private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable)
{
if (seen_ops.Contains(saveable.variable))
throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}");

saveables.Add(saveable);
seen_ops.Add(saveable.variable);
}

/// <summary> /// <summary>
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`.
/// </summary> /// </summary>
@@ -136,19 +173,20 @@ namespace Tensorflow
{ {
full_name = name + "_" + attr; full_name = name + "_" + attr;
} }
if(factory.TryGet<BaseResourceVariable>(out var variable))
var op = factory(full_name);
if(op.TryGet<BaseResourceVariable>(out var variable))
{ {
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name))
{ {
yield return op;
yield return v;
} }
} }
else else
{ {
var saveable = factory.GetValue<MySaveableObject>();
foreach (var op in saveable_objects_for_op(saveable, saveable.name))
var saveable = op.GetValue<MySaveableObject>();
foreach (var v in saveable_objects_for_op(saveable, saveable.name))
{ {
yield return op;
yield return v;
} }
} }
} }
@@ -214,20 +252,19 @@ namespace Tensorflow
return names_to_saveables; return names_to_saveables;
} }


public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj)
{ {
// skip the process of type `PythonState` // skip the process of type `PythonState`


if (trackable_has_serialize_to_tensor(obj))
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
{ {
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
var tensor_dict = obj.serialize_to_tensors(); var tensor_dict = obj.serialize_to_tensors();


List<SaveSpec> specs = new(); List<SaveSpec> specs = new();
List<string> local_names = new(); List<string> local_names = new();
string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; string prefix = SaveableCompat.get_saveable_name(obj) ?? "";
foreach(var pair in tensor_dict)
foreach (var pair in tensor_dict)
{ {
var tensor_name = pair.Key; var tensor_name = pair.Key;
var maybe_tensor = pair.Value; var maybe_tensor = pair.Value;
@@ -235,9 +272,9 @@ namespace Tensorflow
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); string spec_name = name + TrackableUtils.escape_local_name(tensor_name);


IDictionary<string, Tensor> internal_dict; IDictionary<string, Tensor> internal_dict;
if(maybe_tensor.TryGet<Tensor>(out var tensor))
if (maybe_tensor.TryGet<Tensor>(out var tensor))
{ {
internal_dict= new Dictionary<string, Tensor>();
internal_dict = new Dictionary<string, Tensor>();
internal_dict[""] = tensor; internal_dict[""] = tensor;
} }
else else
@@ -245,13 +282,18 @@ namespace Tensorflow
internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>();
} }


foreach(var item in internal_dict)
foreach (var item in internal_dict)
{ {
specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
} }
} }
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new();
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
return new TrackableSaveable(obj, specs, name, local_names, prefix);
}

if (trackable_has_serialize_to_tensor(obj))
{
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new();
res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable;
return res; return res;
} }
else else
@@ -333,6 +375,28 @@ namespace Tensorflow
return restored_ops; return restored_ops;
}; };
} }

/// <summary>
/// Returns a dict of SaveableObject factories generated from loaded fns.
/// </summary>
/// <param name="saveable_fn_by_name"></param>
/// <param name="temp_session"></param>
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{
if (saveable_fn_by_name.Count > 0)
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
return res;
}

public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
bool call_with_mapped_captures = false)
{
return factory(key);
}
} }


public class SaveableCompatibilityConverter: Trackable public class SaveableCompatibilityConverter: Trackable


+ 43
- 10
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -20,8 +20,8 @@ using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;
using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving;
using Tensorflow.Training; using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Train namespace Tensorflow.Train
@@ -41,9 +41,10 @@ namespace Tensorflow.Train
protected IDictionary<string, Trackable> _unconditional_dependency_names; protected IDictionary<string, Trackable> _unconditional_dependency_names;


protected IList<TrackableReference> _unconditional_checkpoint_dependencies; protected IList<TrackableReference> _unconditional_checkpoint_dependencies;
protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies;


protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories =
new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
private bool _manual_tracking = true; private bool _manual_tracking = true;


private static Trackable _none = new AutoTrackable(); private static Trackable _none = new AutoTrackable();
@@ -71,6 +72,18 @@ namespace Tensorflow.Train
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; }
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; }
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; }
public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies;
public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories
{
get
{
return _self_saveable_object_factories;
}
set
{
_self_saveable_object_factories = value;
}
}


/// <summary> /// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`. /// Restore-on-create for a variable be saved with this `Checkpointable`.
@@ -136,9 +149,11 @@ namespace Tensorflow.Train
_self_update_uid = -1; _self_update_uid = -1;
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); _unconditional_checkpoint_dependencies = new List<TrackableReference>();
_unconditional_dependency_names = new Dictionary<string, Trackable>(); _unconditional_dependency_names = new Dictionary<string, Trackable>();
_unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>();
} }


public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache)
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT,
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
_maybe_initialize_trackable(); _maybe_initialize_trackable();
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
@@ -174,10 +189,19 @@ namespace Tensorflow.Train
/// <param name="trackable"></param> /// <param name="trackable"></param>
public virtual void _handle_deferred_dependencies(string name, Trackable trackable) public virtual void _handle_deferred_dependencies(string name, Trackable trackable)
{ {
//_maybe_initialize_trackable();
//trackable._maybe_initialize_trackable();
// TODO: complete the implementation.
_maybe_initialize_trackable();
trackable._maybe_initialize_trackable();

if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies))
{
_unconditional_deferred_dependencies.Remove(name);
foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid))
{
checkpoint_position.restore(trackable);
}
}

// TODO(Rinne): deal with `_self_name_based_restores`
} }


public virtual Trackable? _lookup_dependency(string name) public virtual Trackable? _lookup_dependency(string name)
@@ -225,12 +249,19 @@ namespace Tensorflow.Train
return self_tensor_map.Keys.ToList(); return self_tensor_map.Keys.ToList();
} }


public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{ {
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
{
throw new NotImplementedException();
//return new TrackableSaveable(this, null, name, null, null);
}
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{ {
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
throw new NotImplementedException();
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new();
res[""] = create_saveable;
return res;
} }
else else
{ {
@@ -259,4 +290,6 @@ namespace Tensorflow.Train
} }


public record class TrackableReference(string Name, Trackable Refer); public record class TrackableReference(string Name, Trackable Refer);

public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName);
} }

+ 4
- 3
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Exceptions; using Tensorflow.Exceptions;
using Tensorflow.Train; using Tensorflow.Train;


@@ -20,9 +21,9 @@ public static class TrackableUtils
LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable());
} }
} }
private static string _ESCAPE_CHAR = ".";
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
internal static string _ESCAPE_CHAR = ".";
internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
{ {


+ 9
- 5
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -5,9 +5,9 @@ using Tensorflow.Variables;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.ModelSaving;
using System.Diagnostics; using System.Diagnostics;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel;


namespace Tensorflow namespace Tensorflow
{ {
@@ -19,7 +19,11 @@ namespace Tensorflow
protected TF_DataType _dtype; protected TF_DataType _dtype;
public TF_DataType dtype => _dtype; public TF_DataType dtype => _dtype;
protected string _handle_name; protected string _handle_name;
protected string handle_name => _handle_name;
public string handle_name
{
get { return _handle_name; }
set { _handle_name = value; }
}


protected string _unique_id; protected string _unique_id;
public string UniqueId => _unique_id; public string UniqueId => _unique_id;
@@ -289,10 +293,10 @@ namespace Tensorflow
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
} }


public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{ {
var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this;
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this;
return res; return res;
} }




+ 18
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -238,5 +238,23 @@ namespace Tensorflow
{ {
return _graph_element.eval(session); return _graph_element.eval(session);
} }

public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable(
VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name)
{
if(aggregation is null)
{
aggregation = VariableAggregation.None;
}
if(synchronization is null)
{
synchronization = VariableSynchronization.Auto;
}
if (trainable is null)
{
trainable = synchronization != VariableSynchronization.OnRead;
}
return (synchronization.Value, aggregation.Value, trainable.Value);
}
} }
} }

+ 3
- 8
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine
/// </summary> /// </summary>
/// <param name="config"></param> /// <param name="config"></param>
/// <returns></returns> /// <returns></returns>
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null)
{ {
// Layer instances created during the graph reconstruction process. // Layer instances created during the graph reconstruction process.
var created_layers = new Dictionary<string, ILayer>();
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>(); var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>(); var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
@@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine
layer = created_layers[layer_name]; layer = created_layers[layer_name];
else else
{ {
layer = layer_data.ClassName switch
{
"InputLayer" => InputLayer.from_config(layer_data.Config),
"Dense" => Dense.from_config(layer_data.Config),
_ => throw new NotImplementedException("")
};
layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config);


created_layers[layer_name] = layer; created_layers[layer_name] = layer;
} }


+ 13
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine
Inputs = inputs, Inputs = inputs,
Outputs = outputs Outputs = outputs
}) })
{
Initialize(inputs, outputs, name);
}

internal void Initialize(Tensors inputs, Tensors outputs, string name = null)
{ {
_input_layers = new List<ILayer>(); _input_layers = new List<ILayer>();
_output_layers = new List<ILayer>(); _output_layers = new List<ILayer>();
@@ -70,7 +75,14 @@ namespace Tensorflow.Keras.Engine
this.inputs = inputs; this.inputs = inputs;
this.outputs = outputs; this.outputs = outputs;
built = true; built = true;
_buildInputShape = inputs.shape;
if(inputs.Length > 0)
{
_buildInputShape = inputs.shape;
}
else
{
_buildInputShape = new Saving.TensorShapeConfig();
}


if (outputs.Any(x => x.KerasHistory == null)) if (outputs.Any(x => x.KerasHistory == null))
base_layer_utils.create_keras_history(outputs); base_layer_utils.create_keras_history(outputs);


+ 26
- 0
src/TensorFlowNET.Keras/Engine/Layer.Layers.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine


public virtual Shape ComputeOutputShape(Shape input_shape) public virtual Shape ComputeOutputShape(Shape input_shape)
=> throw new NotImplementedException(""); => throw new NotImplementedException("");

protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false)
{
List<IVariableV1> res = new();
var nested_layers = _flatten_layers(false, false);
foreach (var layer in nested_layers)
{
if (layer is Layer l)
{
if (include_trainable == true && include_non_trainable == true)
{
res.AddRange(l.Variables);
}
else if (include_trainable == true && include_non_trainable == false)
{
res.AddRange(l.TrainableVariables);
}
else if(include_trainable == false && include_non_trainable == true)
{
res.AddRange(l.NonTrainableVariables);
}
}
}
return res;
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -12,7 +12,7 @@ public abstract partial class Layer


public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;


public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;
public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata;


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {


+ 70
- 42
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
@@ -66,16 +67,74 @@ namespace Tensorflow.Keras.Engine
public bool SupportsMasking { get; set; } public bool SupportsMasking { get; set; }
protected List<IVariableV1> _trainable_weights; protected List<IVariableV1> _trainable_weights;


public virtual List<IVariableV1> TrainableVariables => _trainable_weights;
public virtual List<IVariableV1> TrainableVariables => TrainableWeights;


protected List<IVariableV1> _non_trainable_weights; protected List<IVariableV1> _non_trainable_weights;
public List<IVariableV1> non_trainable_variables => _non_trainable_weights;
public List<IVariableV1> NonTrainableVariables => NonTrainableWeights;
public List<IVariableV1> Variables => Weights;

public virtual List<IVariableV1> TrainableWeights
{
get
{
if (!this.Trainable)
{
return new List<IVariableV1>();
}
var children_weights = _gather_children_variables(true);
return children_weights.Concat(_trainable_weights).Distinct().ToList();
}
}

public virtual List<IVariableV1> NonTrainableWeights
{
get
{
if (!this.Trainable)
{
var children_weights = _gather_children_variables(true, true);
return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList();
}
else
{
var children_weights = _gather_children_variables(include_non_trainable: true);
return children_weights.Concat(_non_trainable_weights).Distinct().ToList();
}
}
}

public virtual List<IVariableV1> Weights
{
get
{
return TrainableWeights.Concat(NonTrainableWeights).ToList();
}
set
{
if (Weights.Count() != value.Count()) throw new ValueError(
$"You called `set_weights` on layer \"{this.name}\"" +
$"with a weight list of length {len(value)}, but the layer was " +
$"expecting {len(Weights)} weights.");
foreach (var (this_w, v_w) in zip(Weights, value))
this_w.assign(v_w, read_value: true);
}
}


protected int id; protected int id;
public int Id => id; public int Id => id;
protected string name; protected string name;
protected string base_name; protected string base_name;
public string Name => name;
public string Name
{
get
{
return name;
}
set
{
name = value;
}
}


protected bool computePreviousMask; protected bool computePreviousMask;
protected List<Operation> updates; protected List<Operation> updates;
@@ -85,10 +144,11 @@ namespace Tensorflow.Keras.Engine


List<INode> inboundNodes; List<INode> inboundNodes;
public List<INode> InboundNodes => inboundNodes; public List<INode> InboundNodes => inboundNodes;

List<INode> outboundNodes; List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes; public List<INode> OutboundNodes => outboundNodes;


public JObject SerializedAttributes { get; set; }

ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value; public CallContext CallContext => callContext.Value;
public Tensor[] input public Tensor[] input
@@ -117,6 +177,11 @@ namespace Tensorflow.Keras.Engine
protected List<ILayer> _self_tracked_trackables; protected List<ILayer> _self_tracked_trackables;


public Layer(LayerArgs args) public Layer(LayerArgs args)
{
Initialize(args);
}

internal virtual void Initialize(LayerArgs args)
{ {
this.args = args; this.args = args;
// A stateful layer is a layer whose updates are run during inference too, // A stateful layer is a layer whose updates are run during inference too,
@@ -273,46 +338,9 @@ namespace Tensorflow.Keras.Engine
public int count_params() public int count_params()
{ {
if (Trainable) if (Trainable)
return layer_utils.count_params(this, weights);
return layer_utils.count_params(this, Weights);
return 0; return 0;
} }
List<IVariableV1> ILayer.TrainableWeights
{
get
{
return _trainable_weights;
}
}

List<IVariableV1> ILayer.NonTrainableWeights
{
get
{
return _non_trainable_weights;
}
}

public List<IVariableV1> weights
{
get
{
var weights = new List<IVariableV1>();
weights.AddRange(_trainable_weights);
weights.AddRange(_non_trainable_weights);
return weights;
}
set
{
if (weights.Count() != value.Count()) throw new ValueError(
$"You called `set_weights` on layer \"{this.name}\"" +
$"with a weight list of length {len(value)}, but the layer was " +
$"expecting {len(weights)} weights.");
foreach (var (this_w, v_w) in zip(weights, value))
this_w.assign(v_w, read_value: true);
}
}

public List<IVariableV1> Variables => weights;


public virtual IKerasConfig get_config() public virtual IKerasConfig get_config()
=> args; => args;


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine
{ {
using (SharedObjectSavingScope.Enter()) using (SharedObjectSavingScope.Enter())
{ {
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
} }
} }
} }


+ 38
- 7
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine
IVariableV1 _predict_counter; IVariableV1 _predict_counter;
bool _base_model_initialized; bool _base_model_initialized;
bool stop_training; bool stop_training;

public bool IsGraphNetwork => _is_graph_network;
public OptimizerV2 Optimizer public OptimizerV2 Optimizer
{ {
@@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine
_init_batch_counters(); _init_batch_counters();
} }


internal override void Initialize(LayerArgs args)
{
_init_batch_counters();
base.Initialize(args);
}

void _configure_steps_per_execution(int steps_per_execution) void _configure_steps_per_execution(int steps_per_execution)
{ {
_steps_per_execution = tf.Variable(steps_per_execution, _steps_per_execution = tf.Variable(steps_per_execution,
@@ -81,10 +89,11 @@ namespace Tensorflow.Keras.Engine
public override List<ILayer> Layers public override List<ILayer> Layers
=> _flatten_layers(recursive: false, include_self: false).ToList(); => _flatten_layers(recursive: false, include_self: false).ToList();


public override List<IVariableV1> TrainableVariables
public override List<IVariableV1> TrainableWeights
{ {
get get
{ {
// skip the assertion of weights created.
var variables = new List<IVariableV1>(); var variables = new List<IVariableV1>();


if (!Trainable) if (!Trainable)
@@ -95,18 +104,40 @@ namespace Tensorflow.Keras.Engine
foreach (var trackable_obj in _self_tracked_trackables) foreach (var trackable_obj in _self_tracked_trackables)
{ {
if (trackable_obj.Trainable) if (trackable_obj.Trainable)
variables.AddRange(trackable_obj.TrainableVariables);
variables.AddRange(trackable_obj.TrainableWeights);
} }


foreach (var layer in _self_tracked_trackables)
variables.AddRange(_trainable_weights);

return variables.Distinct().ToList();
}
}

public override List<IVariableV1> NonTrainableWeights
{
get
{
// skip the assertion of weights created.
var variables = new List<IVariableV1>();

foreach (var trackable_obj in _self_tracked_trackables)
{ {
if (layer.Trainable)
variables.AddRange(layer.TrainableVariables);
variables.AddRange(trackable_obj.NonTrainableWeights);
} }


// variables.AddRange(_trainable_weights);
if (!Trainable)
{
var trainable_variables = new List<IVariableV1>();
foreach (var trackable_obj in _self_tracked_trackables)
{
variables.AddRange(trackable_obj.TrainableWeights);
}
variables.AddRange(trainable_variables);
variables.AddRange(_trainable_weights);
variables.AddRange(_non_trainable_weights);
}


return variables;
return variables.Distinct().ToList();
} }
} }




+ 10
- 5
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine
: base(args.Inputs, args.Outputs, name: args.Name) : base(args.Inputs, args.Outputs, name: args.Name)
{ {
this.args = args; this.args = args;
if (args.Layers == null)
args.Layers = new List<ILayer>();
// SupportsMasking = true; // SupportsMasking = true;
_compute_output_and_mask_jointly = true; _compute_output_and_mask_jointly = true;
_auto_track_sub_layers = false; _auto_track_sub_layers = false;
@@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine
_created_nodes = new List<INode>(); _created_nodes = new List<INode>();


// Add to the model any layers passed to the constructor. // Add to the model any layers passed to the constructor.
if (args.Layers != null)
if (args.Layers is not null)
{ {
foreach (var layer in args.Layers)
add(layer);
InitLayers(args.Layers);
}
}

public void InitLayers(IEnumerable<ILayer> layers)
{
foreach(var layer in layers)
{
add(layer);
} }
} }




+ 1
- 2
src/TensorFlowNET.Keras/Layers/Activation/ELU.cs View File

@@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers {
{ {
throw new ValueError("Alpha must be a number greater than 0."); throw new ValueError("Alpha must be a number greater than 0.");
} }
_buildInputShape = input_shape;
built = true;
base.build(input_shape);
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 1
- 2
src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs View File

@@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers {
} }
public override void build(Shape input_shape) public override void build(Shape input_shape)
{ {
_buildInputShape = input_shape;
built = true;
base.build(input_shape);
} }
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {


+ 1
- 2
src/TensorFlowNET.Keras/Layers/Activation/SELU.cs View File

@@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers {
if ( alpha < 0f ) { if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0."); throw new ValueError("Alpha must be a number greater than 0.");
} }
_buildInputShape = input_shape;
built = true;
base.build(input_shape);
} }
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor output = inputs; Tensor output = inputs;


+ 0
- 5
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers


return outputs; return outputs;
} }

public static Dense from_config(LayerArgs args)
{
return new Dense(args as DenseArgs);
}
} }
} }

+ 0
- 5
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers
name: Name); name: Name);
} }


public static InputLayer from_config(LayerArgs args)
{
return new InputLayer(args as InputLayerArgs);
}

public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this);
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Metrics/Metric.cs View File

@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics


public virtual void reset_states() public virtual void reset_states()
{ {
foreach (var v in weights)
foreach (var v in Weights)
v.assign(0); v.assign(0);
} }




+ 3
- 13
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -4,6 +4,7 @@ using System.IO;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using ThirdParty.Tensorflow.Python.Keras.Protobuf; using ThirdParty.Tensorflow.Python.Keras.Protobuf;


namespace Tensorflow.Keras.Models namespace Tensorflow.Keras.Models
@@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models
public Functional from_config(ModelConfig config) public Functional from_config(ModelConfig config)
=> Functional.from_config(config); => Functional.from_config(config);


public void load_model(string filepath, bool compile = true)
public Model load_model(string filepath, bool compile = true, LoadOptions? options = null)
{ {
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb"));
var saved_mode = SavedModel.Parser.ParseFrom(bytes);
var meta_graph_def = saved_mode.MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;

bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb"));
var metadata = SavedMetadata.Parser.ParseFrom(bytes);

// Recreate layers and metrics using the info stored in the metadata.
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
} }
} }
} }

+ 527
- 34
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -1,12 +1,24 @@
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Training;
using ThirdParty.Tensorflow.Python.Keras.Protobuf; using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;


@@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving
{ {
public class KerasObjectLoader public class KerasObjectLoader
{ {
SavedMetadata _metadata;
SavedObjectGraph _proto;
Dictionary<int, string> _node_paths = new Dictionary<int, string>();
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>();
List<int> _traversed_nodes_from_config = new List<int>();
private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects;
private SavedMetadata _metadata;
private SavedObjectGraph _proto;
private Dictionary<int, string> _node_paths = new Dictionary<int, string>();
private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>();
private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>();
private List<int> _traversed_nodes_from_config = new List<int>();
private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes;
private List<int> _models_to_reconstruct;
public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes;

static KerasObjectLoader()
{
PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null;
}


public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
{ {
_metadata = metadata; _metadata = metadata;
_proto = object_graph_def; _proto = object_graph_def;
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath);
_models_to_reconstruct = new List<int>();
loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
} }


/// <summary> /// <summary>
@@ -42,15 +66,255 @@ namespace Tensorflow.Keras.Saving
continue; continue;
} }


_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata);
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata);
}
foreach(var node_metadata in metric_list)
{
try
{
if (node_metadata.Identifier.Equals("_tf_keras_metric"))
{
continue;
}
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier,
node_metadata.Metadata);
}
catch(ValueError e)
{
if (compile)
{
throw e;
}
// TODO: add logging.warning.
}
}
}

public string get_path(int node_id)
{
return _node_paths[node_id];
}

/// <summary>
/// Finish setting up Keras objects.
///
/// This function is executed after all objects and functions have been created.
/// Call functions and losses are attached to each layer, and once all layers
/// have been fully set up, graph networks are initialized.
///
/// Subclassed models that are revived from the SavedModel are treated like
/// layers, and have their call/loss functions attached here.
/// </summary>
public void finalize_objects()
{
List<Layer> layers_revived_from_config = new();
List<Layer> layers_revived_from_saved_model = new();
foreach(var item in loaded_nodes)
{
var node_id = item.Key;
var node = item.Value.Item1;
if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id))
{
continue;
}

_unblock_model_reconstruction(node_id, node as Layer);

if(node is InputLayer or Metric)
{
continue;
}

// TODO: deal with `RevivedLayer` and `RevivedInputLayer`.
layers_revived_from_config.Add(node as Layer);
}

_finalize_saved_model_layers(layers_revived_from_saved_model);
_finalize_config_layers(layers_revived_from_config);

_reconstruct_all_models();
}

private void _reconstruct_all_models()
{
HashSet<int> all_initialized_models = new();
for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--)
{
int model_id = _models_to_reconstruct[i];
all_initialized_models.Add(model_id);
var (model, layers) = model_layer_dependencies[model_id];
_reconstruct_model(model_id, model, layers.ToList());
_finalize_config_layers(new List<Layer>() { model });
}

Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys));
}

private void _reconstruct_model(int model_id, Model model, List<Layer> layers)
{
var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"];

if(model.input is not null && model.input.Length > 0)
{

}
else if(model is Sequential s)
{
if(layers is null || layers.Count == 0 || layers[0] is not InputLayer)
{
if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
{
layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
}
else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
{
// TODO(Rinne): implement it
}
}
// `model.__init__(layers, config["name"])`
s.InitLayers(layers);
s.Name = config["name"].ToObject<string>();
if(s.input is null || s.input.Length == 0)
{
var first_layer = _get_child_layer_node_ids(model_id)[0];
var input_specs = _infer_inputs(first_layer);
var input_shapes = _infer_inputs(first_layer, true);
// `model._set_inputs(input_specs)`

// skip the check of input_specs is Dictionary
if (!s.Built)
{
s.build(input_shapes);
}
}
}
else
{
// skip the parameter `created_layers`.
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config),
layers.ToDictionary(x => x.Name, x => x as ILayer));
// skip the `model.__init__`
(model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
(model as Functional).connect_ancillary_layers(created_layers);
}

_set_network_attributes_from_metadata(model);
_unblock_model_reconstruction(model_id, model);
}

private void _set_network_attributes_from_metadata(Model revived_object)
{
// TODO: implement it.
}

/// <summary>
/// Runs the final steps of loading Keras Layers from config.
/// </summary>
/// <param name="layers"></param>
private void _finalize_config_layers(List<Layer> layers)
{
foreach(var layer in layers)
{
if (_is_graph_network(layer))
{
_restore_layer_unconditional_losses(layer);
}
_restore_layer_activation_loss(layer);
_restore_layer_metrics(layer);

// TODO(Rinne): deal with RNN.
}
}

/// <summary>
/// Runs the final steps of loading Keras Layers from SavedModel.
/// </summary>
/// <param name="layers"></param>
private void _finalize_saved_model_layers(List<Layer> layers)
{
foreach(var layer in layers)
{
// TODO(Rinne): deal with `RevivedNetwork`.
_restore_layer_unconditional_losses(layer);
_restore_layer_activation_loss(layer);
_restore_layer_metrics(layer);
}
}

private void _restore_layer_unconditional_losses(Layer layer)
{
// TODO(Rinne): implement it.
}

private void _restore_layer_activation_loss(Layer layer)
{
// TODO(Rinne): implement it.
}

private void _restore_layer_metrics(Layer layer)
{
// TODO(Rinne): implement it.
}

/// <summary>
/// Removes layer from blocking model reconstruction.
/// </summary>
/// <param name="layer_id"></param>
/// <param name="layer"></param>
private void _unblock_model_reconstruction(int layer_id, Layer layer)
{
foreach(var depencency in model_layer_ids_dependencies)
{
var layer_ids = depencency.Value.Item2;
var layers = model_layer_dependencies.SetDefault(depencency.Key,
(depencency.Value.Item1, new Layer[depencency.Value.Item2.Length])).Item2;
if (!layer_ids.Contains(layer_id))
{
continue;
}
layers[Array.IndexOf(layer_ids, layer_id)] = layer;
if (layers.All(x => x is not null))
{
_models_to_reconstruct.Add(depencency.Key);
}
} }
} }


void _load_layer(int node_id, string identifier, string metadata_json)
private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{ {
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
_revive_from_config(identifier, metadata, node_id);

if (loaded_nodes.ContainsKey(node_id))
{
var (node, setter) = loaded_nodes[node_id];

_maybe_add_serialized_attributes(node as Layer, metadata);
var config = metadata.Config;
if(_is_graph_network(node as Layer) && generic_utils.validate_config(config))
{
Debug.Assert(node is Model);
var child_nodes = _get_child_layer_node_ids(node_id);
model_layer_ids_dependencies[node_id] = (node as Model, child_nodes);
if(child_nodes is null || child_nodes.Length == 0)
{
_models_to_reconstruct.Add(node_id);
}
}
return (node, setter);
}
else
{
var (obj, setter) = _revive_from_config(identifier, metadata, node_id);
if (obj is null)
{
(obj, setter) = _revive_custom_object(identifier, metadata);
}
Debug.Assert(obj is Layer);
_maybe_add_serialized_attributes(obj as Layer, metadata);
return (obj, setter);
}
} }


/// <summary> /// <summary>
@@ -59,11 +323,34 @@ namespace Tensorflow.Keras.Saving
/// <param name="identifier"></param> /// <param name="identifier"></param>
/// <param name="metadata"></param> /// <param name="metadata"></param>
/// <param name="node_id"></param> /// <param name="node_id"></param>
void _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
{ {
var obj = _revive_graph_network(identifier, metadata, node_id);
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
Trackable obj;
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
{
// TODO(Rinne): implement it.
return (null, null);
//throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}
else
{
obj = _revive_graph_network(identifier, metadata, node_id);
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
}

if(obj is null)
{
return (null, null);
}
var setter = _config_node_setter(_revive_setter);
_add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id);
return (obj, setter);
}

private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
} }


Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id)
@@ -71,6 +358,12 @@ namespace Tensorflow.Keras.Saving
var config = metadata.Config; var config = metadata.Config;
var class_name = metadata.ClassName; var class_name = metadata.ClassName;
Model model = null; Model model = null;

if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional")
{
return null;
}

if (class_name == "Sequential") if (class_name == "Sequential")
{ {
model = new Sequential(new SequentialArgs model = new Sequential(new SequentialArgs
@@ -78,34 +371,82 @@ namespace Tensorflow.Keras.Saving
Name = config.GetValue("name").ToString() Name = config.GetValue("name").ToString()
}); });
} }
else if (class_name == "Functional")
else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER)
{ {
throw new NotImplementedException("");
model = new Sequential(new SequentialArgs
{
Name = class_name
});
}
else
{
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>());
} }

if (!metadata.IsGraphNetwork)
return null;


// Record this model and its layers. This will later be used to reconstruct // Record this model and its layers. This will later be used to reconstruct
// the model. // the model.
var layers = _get_child_layer_node_ids(node_id); var layers = _get_child_layer_node_ids(node_id);
model_layer_dependencies[node_id] = (model, layers);
model_layer_ids_dependencies[node_id] = (model, layers);
if(layers is null || layers.Length == 0)
{
_models_to_reconstruct.Add(node_id);
}
return model; return model;
} }


Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
{ {
var config = metadata.Config; var config = metadata.Config;
var class_name = metadata.ClassName; var class_name = metadata.ClassName;
var shared_object_id = metadata.SharedObjectId; var shared_object_id = metadata.SharedObjectId;
var must_restore_from_config = metadata.MustRestoreFromConfig; var must_restore_from_config = metadata.MustRestoreFromConfig;
var obj = class_name switch
{
"Resizing" => Resizing.from_config(config),
_ => throw new NotImplementedException("")
};

var obj = generic_utils.deserialize_keras_object(class_name, config);

obj.Name = metadata.Name;
// TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec`

var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); var built = _try_build_layer(obj, node_id, metadata.BuildInputShape);
return null;
if (!built)
{
return null;
}
return obj;
}

private void _revive_setter(object layer, object name, object value)
{
Debug.Assert(name is string);
Debug.Assert(layer is Layer);
if(PUBLIC_ATTRIBUTES.ContainsKey(name as string))
{
if(value is Trackable)
{
(layer as Layer)._track_trackable(value as Trackable, name as string);
}
if((layer as Layer).SerializedAttributes is null)
{
(layer as Layer).SerializedAttributes = new JObject();
}
(layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value);
}
else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success)
{
(layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true);
}
else
{
var properties = layer.GetType().GetProperties();
foreach(var p in properties)
{
if(p.Name == name as string && p.GetValue(layer) is not null)
{
return;
}
}
Loader.setattr(layer, name, value);
}
} }


/// <summary> /// <summary>
@@ -143,34 +484,186 @@ namespace Tensorflow.Keras.Saving
/// <param name="obj"></param> /// <param name="obj"></param>
/// <param name="proto"></param> /// <param name="proto"></param>
/// <param name="node_id"></param> /// <param name="node_id"></param>
void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id)
void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id)
{ {
if (_traversed_nodes_from_config.Contains(node_id)) if (_traversed_nodes_from_config.Contains(node_id))
return; return;
var parent_path = _node_paths[node_id]; var parent_path = _node_paths[node_id];
_traversed_nodes_from_config.Add(node_id); _traversed_nodes_from_config.Add(node_id);
if (!obj.Built)
obj._maybe_initialize_trackable();

if(obj is Layer layer && !layer.Built)
{ {
var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
_try_build_layer(obj, node_id, metadata.BuildInputShape);
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata);
_try_build_layer(layer, node_id, metadata.BuildInputShape);
}


List<(Trackable, int, string)> children = new();
foreach(var refer in proto.Children)
{
var obj_child = obj._lookup_dependency(refer.LocalName);
children.Add((obj_child, refer.NodeId, refer.LocalName));
}

var metric_list_node_id = _search_for_child_node(node_id, new string[] {
Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics"
});
if(metric_list_node_id is not null && obj is Model model && model.metrics is not null)
{
var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x);
foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children)
{
if (obj_metrics.TryGetValue(refer.LocalName, out var metric))
{
var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}";
children.Add((metric as Metric, refer.NodeId, metric_path));
}
}
}

foreach(var (obj_child, child_id, child_name) in children)
{
if(obj_child is null)
{
continue;
}
var child_proto = _proto.Nodes[child_id];

// skip the check for registered identifier

Action<object, object, object> setter;
if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier))
{
setter = _revive_setter;
}
else
{
setter = Loader.setattr;
}

if (loaded_nodes.ContainsKey(child_id))
{
// skip the logging.warning
continue;
}

if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name))
{
(obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0";
}

if(obj_child is TrackableDataStructure)
{
setter = (x, y, z) => { };
}

var child_path = $"{parent_path}.{child_name}";
_node_paths[child_id] = child_path;
_add_children_recreated_from_config(obj_child, child_proto, child_id);
loaded_nodes[child_id] = (obj_child, setter);
} }
} }


bool _try_build_layer(Model obj, int node_id, Shape build_input_shape)
private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape)
{ {
if (obj.Built) if (obj.Built)
return true; return true;


if(build_input_shape is null)
{
build_input_shape = _infer_inputs(node_id, convert_to_shapes: true);
}

if(build_input_shape is not null)
{
obj.build(build_input_shape);
// In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`.
// On the one hand, C# does not support call a method from specified parent class.
// On the other hand, currently All class derived from Layer call `Layer.Build` or
// move the implementation of `Layer.build` to its own `build` method.
// Therefore we do not call it here.
// However, it's still quite risky once in the future a certain class derived from
// `Layer` does not call `Layer.build`.

return true;
}

return false; return false;
} }


bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape)
/// <summary>
/// Infers input shape of layer from SavedModel functions.
/// </summary>
/// <param name="layer_node_id"></param>
/// <param name="convert_to_shapes"></param>
/// <returns></returns>
private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false)
{ {
if (obj.Built)
return true;
var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" });
if(call_fn_id is null)
{
return null;
}


var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions;
if(concrete_functions is null)
{
return null;
}
var call_fn_name = concrete_functions[0];
var call_fn_proto = _proto.ConcreteFunctions[call_fn_name];
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}

private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child)
{
if(path_to_child is null || path_to_child.Count() == 0)
{
return parent_id;
}

foreach(var child in _proto.Nodes[parent_id].Children)
{
if(child.LocalName == path_to_child.First())
{
return _search_for_child_node(child.NodeId, path_to_child.Skip(1));
}
}
return null;
}

private bool _is_graph_network(Layer layer)
{
// TODO: deal with `RevivedLayer`
if(layer is Functional)
{
return (layer as Functional).IsGraphNetwork || layer is Sequential;
}
return false; return false;
} }

private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata)
{
// TODO: deal with `RevivedLayer`
}

/// <summary>
/// Creates edges for nodes that are recreated from config.
/// </summary>
/// <returns></returns>
private Action<object, object, object> _config_node_setter(Action<object, object, object> setter)
{
void setattr_wrapper(object obj, object name, object value)
{
Debug.Assert(obj is Trackable);
Debug.Assert(name is string);
if((obj as Trackable)._lookup_dependency(name as string) is null)
{
setter(obj, name, value);
}
}
return setattr_wrapper;
}
} }
} }

+ 3
- 3
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;


public partial class KerasSavedModelUtils public partial class KerasSavedModelUtils
{ {
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
SaveOptions? options, bool save_traces = true) SaveOptions? options, bool save_traces = true)
{ {
if (!overwrite && File.Exists(filepath)) if (!overwrite && File.Exists(filepath))
@@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils
BadConsumers = { } BadConsumers = { }
}, },
Identifier = layer.ObjectIdentifier, Identifier = layer.ObjectIdentifier,
Metadata = layer.TrackingMetadata
Metadata = layer.GetTrackingMetadata()
}; };


metadata.Nodes.Add(saved_object); metadata.Nodes.Add(saved_object);
@@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
})); }));
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x =>
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x =>
{ {
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");


+ 96
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/load.cs View File

@@ -0,0 +1,96 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class KerasLoadModelUtils
{
/// <summary>
/// Corresponding to keras/saving/save.py/load_model
/// </summary>
/// <param name="filepath"></param>
/// <param name="custom_objects"></param>
/// <param name="compile"></param>
/// <param name="options"></param>
/// <returns></returns>
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
bool compile = true, LoadOptions? options = null)
{
using (SharedObjectSavingScope.Enter())
{
using (LoadContext.load_context(options))
{
if (!File.Exists(filepath) && !Directory.Exists(filepath))
{
throw new IOException($"No file or directory found at {filepath}.");
}
if (Directory.Exists(filepath))
{
return load(filepath, compile, options);
}
else
{
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
}
}
}
}

private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
{
SavedMetadata metadata = new SavedMetadata();
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
if (File.Exists(path_to_metadata_pb))
{
metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read));
}
else
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}

if (metadata.Nodes is null || metadata.Nodes.Count == 0)
{
return Loader.load(path, options: options) as Model;
}

var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);

Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
nodes_to_load["root"] = (null, null);
foreach(var item in keras_loader.LoadedNodes)
{
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
}
var loaded = Loader.load_partial(path, nodes_to_load, options);

keras_loader.finalize_objects();
// keras_loader.del_tracking();

var model = loaded["root"];

if(model is Model && compile)
{
// TODO(Rinne): implement it.
}

if (!tf.Context.executing_eagerly())
{
// TODO(Rinne): implement it.
}

return model;
}
}
}

+ 69
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs View File

@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow.Keras.Saving.SavedModel
{
// TODO: remove this class to common project.
public class ContextHandler: IDisposable
{
public Action<bool> DisposeCallBack { get; set; }
public void Dispose()
{
DisposeCallBack.Invoke(true);
}
}
public class LoadContext
{
private bool _entered_load_context;
private LoadOptions? _load_options;
private static ThreadLocal<LoadContext> _load_context = new();
private LoadContext()
{
_entered_load_context = false;
_load_options = null;
}

public void set_load_options(LoadOptions load_options)
{
_load_options = load_options;
_entered_load_context = true;
}

private void clear_load_options()
{
_load_options = null;
_entered_load_context = false;
}

private LoadOptions? load_options()
{
return _load_options;
}

public static ContextHandler load_context(LoadOptions? load_options)
{
if(_load_context.Value is null)
{
_load_context.Value = new LoadContext();
}
_load_context.Value.set_load_options(load_options);
return new ContextHandler()
{
DisposeCallBack = _ => _load_context.Value.clear_load_options()
};
}

public static LoadOptions? get_load_option()
{
return _load_context.Value.load_options();
}

public static bool in_load_context()
{
return _load_context.Value._entered_load_context;
}
}
}

+ 68
- 0
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -19,15 +19,21 @@ using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving;
using Tensorflow.Train;


namespace Tensorflow.Keras.Utils namespace Tensorflow.Keras.Utils
{ {
public class generic_utils public class generic_utils
{ {
private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config";
/// <summary> /// <summary>
/// This method does not have corresponding method in python. It's close to `serialize_keras_object`. /// This method does not have corresponding method in python. It's close to `serialize_keras_object`.
/// </summary> /// </summary>
@@ -51,6 +57,58 @@ namespace Tensorflow.Keras.Utils
return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance);
} }


public static Layer deserialize_keras_object(string class_name, JToken config)
{
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);
return layer as Layer;
}

public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);
return layer as Layer;
}

public static LayerArgs deserialize_layer_args(string class_name, JToken config)
{
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
Debug.Assert(args is LayerArgs);
return args as LayerArgs;
}

public static ModelConfig deserialize_model_config(JToken json)
{
ModelConfig config = new ModelConfig();
config.Name = json["name"].ToObject<string>();
config.Layers = new List<LayerConfig>();
var layersToken = json["layers"];
foreach (var token in layersToken)
{
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);
config.Layers.Add(new LayerConfig()
{
Config = args,
Name = token["name"].ToObject<string>(),
ClassName = token["class_name"].ToObject<string>(),
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
});
}
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();
config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>();
return config;
}

public static string to_snake_case(string name) public static string to_snake_case(string name)
{ {
return string.Concat(name.Select((x, i) => return string.Concat(name.Select((x, i) =>
@@ -60,5 +118,15 @@ namespace Tensorflow.Keras.Utils
x.ToString(); x.ToString();
})).ToLower(); })).ToLower();
} }

/// <summary>
/// Determines whether config appears to be a valid layer config.
/// </summary>
/// <param name="config"></param>
/// <returns></returns>
public static bool validate_config(JObject config)
{
return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY);
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Utils/layer_utils.cs View File

@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils
} }


var trainable_count = count_params(model, model.TrainableVariables); var trainable_count = count_params(model, model.TrainableVariables);
var non_trainable_count = count_params(model, model.non_trainable_variables);
var non_trainable_count = count_params(model, model.NonTrainableVariables);


print($"Total params: {trainable_count + non_trainable_count}"); print($"Total params: {trainable_count + non_trainable_count}");
print($"Trainable params: {trainable_count}"); print($"Trainable params: {trainable_count}");


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/bias0.npy View File


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/fingerprint.pb View File


+ 9
- 0
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/keras_metadata.pb View File

@@ -0,0 +1,9 @@

´$root"_tf_keras_network*’${"name": "model", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": []}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]]}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}, "shared_object_id": 9, "input_spec": 
[{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "ndim": 4, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": [], "shared_object_id": 0}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, 
"kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}}}2
† root.layer-0"_tf_keras_input_layer*Ö{"class_name": "InputLayer", "name": "input_1", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}2
Í root.layer-1"_tf_keras_layer*£{"name": "flatten", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 1, "axes": {}}, "shared_object_id": 14}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}}2
¯root.layer_with_weights-0"_tf_keras_layer*ø{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 784}}, "shared_object_id": 15}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2
²root.layer_with_weights-1"_tf_keras_layer*û{"name": "dense_1", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 100}}, "shared_object_id": 16}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 100]}}2
Š root.layer-4"_tf_keras_layer*à{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2
¹Troot.keras_api.metrics.0"_tf_keras_metric*‚{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 17}2
™Uroot.keras_api.metrics.1"_tf_keras_metric*â{"class_name": "MeanMetricWrapper", "name": "sparse_categorical_accuracy", "dtype": "float32", "config": {"name": "sparse_categorical_accuracy", "dtype": "float32", "fn": "sparse_categorical_accuracy"}, "shared_object_id": 18}2

BIN
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/kernel1.npy View File


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/saved_model.pb View File


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.data-00000-of-00001 View File


BIN
test/TensorFlowNET.Keras.UnitTest/Assets/simple_model_from_auto_compile/variables/variables.index View File


+ 68
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -0,0 +1,68 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelLoad
{
[TestMethod]
public void SimpleModelFromAutoCompile()
{
var model = keras.models.load_model(@"Assets/simple_model_from_auto_compile");
model.summary();

model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

// check the weights
var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");

Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second));
Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second));

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 8;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 50000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}

[TestMethod]
public void AlexnetFromSequential()
{
new SequentialModelSave().AlexnetFromSequential();
var model = keras.models.load_model(@"./alexnet_from_sequential");
model.summary();

model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

var num_epochs = 1;
var batch_size = 8;

var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
}
}

test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs → test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs View File

@@ -1,27 +1,21 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Diagnostics;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras; using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Operations;
using System.Diagnostics;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;


namespace TensorFlowNET.Keras.UnitTest.SaveModel; namespace TensorFlowNET.Keras.UnitTest.SaveModel;


[TestClass] [TestClass]
public class SequentialModelTest
public class SequentialModelSave
{ {
[TestMethod] [TestMethod]
public void SimpleModelFromAutoCompile() public void SimpleModelFromAutoCompile()
@@ -63,6 +57,8 @@ public class SequentialModelTest
keras.layers.Softmax(1) keras.layers.Softmax(1)
}); });


model.summary();

model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });


var data_loader = new MnistModelLoader(); var data_loader = new MnistModelLoader();
@@ -82,7 +78,7 @@ public class SequentialModelTest
} }


[TestMethod] [TestMethod]
public void AlexModelFromSequential()
public void AlexnetFromSequential()
{ {
Model model = KerasApi.keras.Sequential(new List<ILayer>() Model model = KerasApi.keras.Sequential(new List<ILayer>()
{ {
@@ -116,7 +112,7 @@ public class SequentialModelTest
keras.layers.Softmax(1) keras.layers.Softmax(1)
}); });


model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" });
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });


var num_epochs = 1; var num_epochs = 1;
var batch_size = 8; var batch_size = 8;
@@ -125,7 +121,7 @@ public class SequentialModelTest


model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);


model.save("./pb_alex_sequential", save_format: "tf");
model.save("./alexnet_from_sequential", save_format: "tf");


// The saved model can be test with the following python code: // The saved model can be test with the following python code:
#region alexnet_python_code #region alexnet_python_code

+ 24
- 0
test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -27,4 +27,28 @@
<ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> <ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" />
</ItemGroup> </ItemGroup>


<ItemGroup>
<None Update="Assets\simple_model_from_auto_compile\fingerprint.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\simple_model_from_auto_compile\keras_metadata.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\simple_model_from_auto_compile\saved_model.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\simple_model_from_auto_compile\variables\variables.data-00000-of-00001">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\simple_model_from_auto_compile\variables\variables.index">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\simple_model_from_auto_compile\kernel1.npy">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Assets\simple_model_from_auto_compile\bias0.npy">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project> </Project>

Loading…
Cancel
Save