* Add CheckpointReader and corresponding C APIs. * Add essential components of SavedModel format loading. * Add checkpoint reading for SavedModel format loading. * Revise customized JSON converters. * Add support for loading models saved from Python. * Fix the duplicated weights in Keras.Model. * Add AlexNet loading test and check the loaded weights. * Fix CI error caused by branch merge. * Resolve the review comments and errors. * Fix training getting stuck when loading a model. * Fix IntPtr handling. --------- Co-authored-by: Haiping Chen <haiping008@gmail.com>
@@ -149,4 +149,22 @@ public static class CheckPointUtils | |||
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | |||
// } | |||
} | |||
/// <summary> | |||
/// Traverse the object graph and list all accessible objects. | |||
/// </summary> | |||
/// <param name="object_graph_view"></param> | |||
public static IList<Trackable> list_objects(ObjectGraphView graph_view) | |||
{ | |||
return objects_ids_and_slot_variables_and_paths(graph_view).Item1; | |||
} | |||
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||
{ | |||
// Filter to objects that actually have saveable attributes. This must be a
// filter, not a prefix: TakeWhile would stop at the first object without
// saveables, which does not match `_objects_with_attributes` in TensorFlow Python.
return full_list.Where(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;
});
} | |||
} |
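A hedged usage sketch of `list_objects` (assuming `root` is any `Trackable` and that `ObjectGraphView` exposes a constructor taking the root, as on the save path):

    var graph_view = new ObjectGraphView(root);
    // Every Trackable reachable from `root`, dependencies included.
    IList<Trackable> objects = CheckPointUtils.list_objects(graph_view);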
@@ -0,0 +1,100 @@ | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Checkpoint | |||
{ | |||
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle | |||
{ | |||
public SafeCheckpointReaderHandle(): base() | |||
{ | |||
} | |||
public SafeCheckpointReaderHandle(IntPtr handle): base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TF_DeleteCheckpointReader(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} | |||
public class CheckpointReader | |||
{ | |||
private SafeCheckpointReaderHandle _handle; | |||
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | |||
public Dictionary<string, Shape> VariableToShapeMap { get; set; } | |||
public CheckpointReader(string filename) | |||
{ | |||
Status status = new Status(); | |||
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | |||
status.Check(true); | |||
ReadAllShapeAndType(); | |||
} | |||
/// <summary>
/// Check whether the checkpoint contains a tensor with the given name.
/// </summary>
/// <param name="name">Name of the tensor to look up.</param>
/// <returns>Non-zero if the tensor exists, zero otherwise.</returns>
public int HasTensor(string name)
{
return c_api.TF_CheckpointReaderHasTensor(_handle, name);
}
/// <summary>
/// Get the name of the variable at the given index.
/// </summary>
/// <param name="index">Index of the variable in the checkpoint.</param>
/// <returns>The variable name.</returns>
public string GetVariable(int index) | |||
{ | |||
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); | |||
} | |||
public int Size() | |||
{ | |||
return c_api.TF_CheckpointReaderSize(_handle); | |||
} | |||
public TF_DataType GetVariableDataType(string name) | |||
{ | |||
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); | |||
} | |||
public Shape GetVariableShape(string name) | |||
{ | |||
int num_dims = GetVariableNumDims(name); | |||
long[] dims = new long[num_dims]; | |||
Status status = new Status(); | |||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | |||
status.Check(true); | |||
return new Shape(dims); | |||
} | |||
public int GetVariableNumDims(string name) | |||
{ | |||
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); | |||
} | |||
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
// Note: `dtype` is currently unused; the tensor is returned with the
// dtype recorded in the checkpoint.
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
status.Check(true);
return new Tensor(tensor);
}
private void ReadAllShapeAndType() | |||
{ | |||
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | |||
VariableToShapeMap = new Dictionary<string, Shape>(); | |||
int size = Size(); | |||
for(int i = 0; i < size; i++) | |||
{ | |||
var name = GetVariable(i); | |||
var shape = GetVariableShape(name); | |||
var dtype = GetVariableDataType(name); | |||
VariableToDataTypeMap[name] = dtype; | |||
VariableToShapeMap[name] = shape; | |||
} | |||
} | |||
} | |||
} |
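A minimal usage sketch for the new reader (the checkpoint prefix "ckpt/model" and the variable key are hypothetical):

    var reader = new CheckpointReader("ckpt/model");
    for (int i = 0; i < reader.Size(); i++)
    {
        var name = reader.GetVariable(i);
        // Shapes and dtypes were cached by ReadAllShapeAndType in the constructor.
        Console.WriteLine($"{name}: {reader.VariableToDataTypeMap[name]} {reader.VariableToShapeMap[name]}");
    }
    if (reader.HasTensor("dense/kernel/.ATTRIBUTES/VARIABLE_VALUE") != 0) // non-zero means present
    {
        var tensor = reader.GetTensor("dense/kernel/.ATTRIBUTES/VARIABLE_VALUE");
    }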
@@ -175,9 +175,9 @@ public static class SaveUtilV1 | |||
{ | |||
var name = factory_data.name; | |||
var key = factory_data.checkpoint_key; | |||
var maybe_saveable = factory_data.factory; | |||
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); | |||
// TODO: oneflow python has a process with callable `saveable_factory`. | |||
// TODO: tensorflow python has a process with callable `saveable_factory`. | |||
List<MySaveableObject> saveables = new(); | |||
if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
{ | |||
@@ -217,7 +217,7 @@ public static class SaveUtilV1 | |||
public record class CheckpointFactoryData | |||
( | |||
Maybe<BaseResourceVariable, MySaveableObject> factory, | |||
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
string name, | |||
string checkpoint_key | |||
); |
@@ -0,0 +1,27 @@ | |||
using System.Runtime.InteropServices; | |||
using Tensorflow.Checkpoint; | |||
namespace Tensorflow | |||
{ | |||
public unsafe partial class c_api | |||
{ | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern void TF_DeleteCheckpointReader(IntPtr reader); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name); | |||
[DllImport(TensorFlowLibName)] | |||
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status); | |||
} | |||
} |
@@ -6,8 +6,12 @@ using System.Linq; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Train; | |||
using Tensorflow.Exceptions; | |||
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
using Newtonsoft.Json; | |||
using Tensorflow.Training; | |||
namespace Tensorflow.Checkpoint; | |||
@@ -21,8 +25,20 @@ public class TrackableSaver | |||
private TrackableObjectGraph _last_save_object_graph; | |||
private Tensor? _object_graph_feed_tensor = null; | |||
private Tensor? _file_prefix_feed_tensor = null; | |||
private Tensor? _file_prefix_placeholder = null; | |||
private Dictionary<Trackable, Trackable>? _object_map = null; | |||
private object? _cache = null; | |||
public Tensor? FilePrefixPlaceHolder | |||
{ | |||
get | |||
{ | |||
return _file_prefix_placeholder; | |||
} | |||
set | |||
{ | |||
_file_prefix_placeholder = value; | |||
} | |||
} | |||
public TrackableSaver(ObjectGraphView graph_view) | |||
{ | |||
_graph_view = graph_view; | |||
@@ -192,4 +208,366 @@ public class TrackableSaver | |||
return save_path; | |||
} | |||
} | |||
public LoadStatus restore(string? save_path, CheckpointOptions? options = null) | |||
{ | |||
if (options is null) | |||
{ | |||
options = new CheckpointOptions(); | |||
} | |||
if(save_path is null) | |||
{ | |||
return new InitializationOnlyStatus(_graph_view, ops.uid()); | |||
} | |||
CheckpointReader reader = new CheckpointReader(save_path); | |||
// Mirrors `graph_building = not context.executing_eagerly()` in TensorFlow Python.
bool graph_building = !tf.Context.executing_eagerly();
Dictionary<string, TF_DataType> dtype_map = null; | |||
if (!graph_building) | |||
{ | |||
dtype_map = reader.VariableToDataTypeMap; | |||
} | |||
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||
Dictionary<Tensor, string> file_prefix_feed_dict; | |||
Tensor file_prefix_tensor; | |||
if (graph_building) | |||
{ | |||
if(_file_prefix_placeholder is null) | |||
{ | |||
tf.device("/cpu:0"); | |||
_file_prefix_placeholder = constant_op.constant("model"); | |||
} | |||
file_prefix_tensor = _file_prefix_placeholder; | |||
file_prefix_feed_dict = new(); | |||
file_prefix_feed_dict[_file_prefix_placeholder] = save_path; | |||
} | |||
else | |||
{ | |||
tf.device("/cpu:0"); | |||
file_prefix_tensor = constant_op.constant(save_path); | |||
file_prefix_feed_dict = null; | |||
} | |||
TrackableObjectGraph object_graph_proto = new(); | |||
if(object_graph_string.ndim > 0) | |||
{ | |||
object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||
} | |||
else | |||
{ | |||
object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); | |||
} | |||
CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( | |||
object_graph_proto: object_graph_proto, | |||
save_path: save_path, | |||
save_path_tensor: file_prefix_tensor, | |||
reader: reader, | |||
restore_op_cache: null, | |||
graph_view: _graph_view, | |||
options: options, | |||
saveables_cache: null | |||
); | |||
new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); | |||
if(_graph_view.AttachedDependencies is not null) | |||
{ | |||
foreach(var refer in _graph_view.AttachedDependencies) | |||
{ | |||
if(refer.Name == "root") | |||
{ | |||
continue; | |||
} | |||
int? proto_id = null; | |||
// Find proto ID of attached dependency (if it is in the proto). | |||
foreach (var proto_refer in object_graph_proto.Nodes[0].Children) | |||
{ | |||
if(proto_refer.LocalName == refer.Name) | |||
{ | |||
proto_id = proto_refer.NodeId; | |||
break; | |||
} | |||
} | |||
if (proto_id is null) | |||
{ | |||
continue; | |||
} | |||
// Object has already been restored. This can happen when there's an | |||
// indirect connection from the attached object to the root. | |||
if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) | |||
{ | |||
continue; | |||
} | |||
new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); | |||
} | |||
} | |||
return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); | |||
} | |||
} | |||
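A hedged end-to-end sketch of the restore path added above (the `root` object and checkpoint prefix are hypothetical; `assert_existing_objects_matched` is the only assertion fully implemented so far):

    var saver = new TrackableSaver(new ObjectGraphView(root));
    LoadStatus status = saver.restore("ckpt/model");
    status.assert_existing_objects_matched(); // throws AssertionError on mismatch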
public class CheckpointRestoreCoordinator | |||
{ | |||
private CheckpointOptions _options; | |||
private TrackableObjectGraph _object_graph_proto; | |||
private int _restore_uid; | |||
private HashSet<int> _matched_proto_ids; | |||
private Tensor _save_path_tensor; | |||
private string _save_path_string; | |||
private CheckpointReader _reader; | |||
private Dictionary<string, TF_DataType> _dtype_map; | |||
private Dictionary<string, Shape> _shape_map; | |||
private ObjectGraphView _graph_view; | |||
private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations; | |||
private bool _expect_partial_attr; | |||
private List<Operation> _restore_ops; | |||
private List<Trackable> _all_trackables; | |||
private Dictionary<int, Trackable> _object_by_proto_id; | |||
private Dictionary<string, Operation> _restore_ops_by_name; | |||
private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations; | |||
private Dictionary<int, IList<string>> _unused_attributes; | |||
public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, | |||
CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) | |||
{ | |||
// TODO(Rinne): cache. | |||
_options = options; | |||
_object_graph_proto = object_graph_proto; | |||
_restore_uid = ops.uid(); | |||
_save_path_tensor = save_path_tensor; | |||
_save_path_string = save_path; | |||
_reader = reader; | |||
if(_reader is null) | |||
{ | |||
_reader = new CheckpointReader(save_path); | |||
} | |||
_dtype_map = _reader.VariableToDataTypeMap; | |||
_shape_map = _reader.VariableToShapeMap; | |||
_graph_view = graph_view; | |||
_restore_ops = new List<Operation>(); | |||
_restore_ops_by_name = new Dictionary<string, Operation>(); | |||
_all_trackables = new List<Trackable>(); | |||
_matched_proto_ids = new HashSet<int>(); | |||
_object_by_proto_id = new Dictionary<int, Trackable>(); | |||
_slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>(); | |||
_deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>(); | |||
_unused_attributes = new Dictionary<int, IList<string>>(); | |||
_expect_partial_attr = false; | |||
for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) | |||
{ | |||
var node = _object_graph_proto.Nodes[i]; | |||
foreach(var slot_reference in node.SlotVariables) | |||
{ | |||
_slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>()) | |||
.Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName)); | |||
} | |||
} | |||
// skip the deleter and cache. | |||
} | |||
public bool ExpectPartial | |||
{ | |||
get | |||
{ | |||
return _expect_partial_attr; | |||
} | |||
set | |||
{ | |||
_expect_partial_attr = value; | |||
} | |||
} | |||
/// <summary> | |||
/// Corresponding to `all_python_objects` of tensorflow python | |||
/// </summary> | |||
public List<Trackable> AllTrackables => _all_trackables; | |||
public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||
public int RestoreUid => _restore_uid; | |||
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | |||
public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations; | |||
public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations; | |||
public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name; | |||
public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes; | |||
public void new_restore_ops(IEnumerable<Operation> new_ops) | |||
{ | |||
_restore_ops.AddRange(new_ops); | |||
// skip the callback. | |||
} | |||
public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
{ | |||
List<Operation> restore_ops = new(); | |||
foreach(var position in positions) | |||
{ | |||
// TODO(Rinne): restore positions keyed by | |||
// `position.ObjectProto.Attributes[0].CheckpointKey`; not implemented yet. | |||
throw new NotImplementedException(); | |||
} | |||
Dictionary<string, BaseResourceVariable> variable_dict = new(); | |||
foreach(var item in tensor_saveables) | |||
{ | |||
if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||
{ | |||
variable_dict[item.Key] = variable; | |||
} | |||
else | |||
{ | |||
throw new TypeError(); | |||
} | |||
} | |||
if (tensor_saveables is not null && tensor_saveables.Count > 0) | |||
{ | |||
var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict); | |||
var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options); | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
foreach(var item in new_restore_ops) | |||
{ | |||
restore_ops.Add(item.Value); | |||
Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key)); | |||
_restore_ops_by_name[item.Key] = item.Value; | |||
} | |||
} | |||
} | |||
return restore_ops; | |||
} | |||
} | |||
public abstract class LoadStatus | |||
{ | |||
public abstract LoadStatus assert_consumed(); | |||
public abstract LoadStatus assert_existing_objects_matched(); | |||
public abstract LoadStatus assert_nontrivial_match(); | |||
public abstract LoadStatus run_restore_ops(Session? session = null); | |||
public abstract void initialize_or_restore(Session? session = null); | |||
public virtual LoadStatus expect_partial() | |||
{ | |||
return this; | |||
} | |||
} | |||
public class InitializationOnlyStatus: LoadStatus | |||
{ | |||
private int _restore_uid; | |||
private ObjectGraphView _object_graph_view; | |||
private Trackable _root; | |||
public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid) | |||
{ | |||
_restore_uid = restore_uid; | |||
_object_graph_view = object_graph_view; | |||
_root = object_graph_view.Root; | |||
} | |||
public override LoadStatus assert_consumed() | |||
{ | |||
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||
} | |||
public override LoadStatus assert_existing_objects_matched() | |||
{ | |||
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||
} | |||
public override LoadStatus assert_nontrivial_match() | |||
{ | |||
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||
} | |||
public override LoadStatus run_restore_ops(Session? session = null) | |||
{ | |||
throw new AssertionError("No checkpoint specified, so no restore ops are available " | |||
+ "(save_path=None to Saver.restore)."); | |||
} | |||
public override void initialize_or_restore(Session? session = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return; | |||
} | |||
if(session is null) | |||
{ | |||
session = new Session(); | |||
} | |||
var trackable_objects = CheckPointUtils.list_objects(_object_graph_view); | |||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
internal class CheckpointLoadStatus: LoadStatus | |||
{ | |||
private CheckpointRestoreCoordinator _checkpoint; | |||
private Dictionary<Tensor, string> _feed_dict; | |||
private ObjectGraphView _object_graph_view; | |||
private Trackable _root; | |||
public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view):base() | |||
{ | |||
_checkpoint = checkpoint; | |||
_feed_dict = feed_dict; | |||
_object_graph_view = graph_view; | |||
_root = graph_view.Root; | |||
} | |||
public CheckpointRestoreCoordinator Checkpoint => _checkpoint; | |||
public override LoadStatus assert_consumed() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public override LoadStatus assert_existing_objects_matched() | |||
{ | |||
for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++) | |||
{ | |||
var node = _checkpoint.ObjectGraphProto.Nodes[i]; | |||
if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) && | |||
trackable.UpdateUid < _checkpoint.RestoreUid) | |||
{ | |||
throw new AssertionError($"Object {node} not assigned a value from checkpoint."); | |||
} | |||
} | |||
foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view)) | |||
{ | |||
if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0) | |||
{ | |||
continue; | |||
} | |||
_checkpoint.AllTrackables.Add(trackable_object); | |||
} | |||
var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables) | |||
.Except(_checkpoint.ObjectByProtoId.Values); | |||
if (unused_trackables.Any()) | |||
{ | |||
var num_unused_trackables = unused_trackables.Count(); | |||
var num_variables_to_show = Math.Min(10, num_unused_trackables); | |||
throw new AssertionError($"Found {num_unused_trackables} Python objects that were " + | |||
$"not bound to checkpointed values, likely due to changes in the " + | |||
$"Python program. Showing {num_variables_to_show} of " + | |||
$"{num_unused_trackables} unmatched objects: " + | |||
$"{{list(unused_python_objects)[:num_variables_to_show]}}"); | |||
} | |||
return this; | |||
} | |||
public override LoadStatus assert_nontrivial_match() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public override LoadStatus expect_partial() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public override void initialize_or_restore(Session? session = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public override LoadStatus run_restore_ops(Session? session = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} |
@@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint | |||
// tf python has code `with ops.device(restore_device):` here. | |||
tf.device(restore_device); // may be risky. | |||
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
int idx = 0; | |||
@@ -0,0 +1,331 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Checkpoint; | |||
public class CheckpointPosition | |||
{ | |||
private CheckpointRestoreCoordinator _checkpoint; | |||
private int _proto_id; | |||
private bool _skip_restore; | |||
public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id) | |||
{ | |||
_checkpoint = checkpoint; | |||
_proto_id = proto_id; | |||
_skip_restore = false; | |||
} | |||
public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; | |||
public CheckpointRestoreCoordinator Checkpoint => _checkpoint; | |||
public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; | |||
public void restore(Trackable trackable) | |||
{ | |||
using (ops.init_scope()) | |||
{ | |||
if (bind_project(trackable)) | |||
{ | |||
var restore_ops = _restore_descendants(); | |||
if(restore_ops is not null && restore_ops.Count > 0) | |||
{ | |||
_checkpoint.new_restore_ops(restore_ops); | |||
} | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// Set a checkpoint<->object correspondence. | |||
/// </summary> | |||
/// <param name="trackable"></param> | |||
/// <returns></returns> | |||
public bool bind_project(Trackable trackable) | |||
{ | |||
_checkpoint.AllTrackables.Add(trackable); | |||
_checkpoint.MatchedProtoIds.Add(_proto_id); | |||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||
{ | |||
// skip the `logging.warning`. | |||
return false; | |||
} | |||
else | |||
{ | |||
_checkpoint.ObjectByProtoId[_proto_id] = trackable; | |||
return true; | |||
} | |||
} | |||
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
{ | |||
// skip the registered_saver | |||
if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||
{ | |||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
new List<CheckpointPosition>(), null); | |||
} | |||
var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); | |||
List<Operation> existing_restore_ops; | |||
List<CheckpointPosition> positions = new(); | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
{ | |||
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||
} | |||
else if(saveable_factories.Count > 0) | |||
{ | |||
(existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return (existing_restore_ops, named_saveables, positions, null); | |||
} | |||
public CheckpointPosition create_child_position(int node_id) | |||
{ | |||
return new CheckpointPosition(_checkpoint, node_id); | |||
} | |||
public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable, | |||
int slot_variable_id, string slot_name) | |||
{ | |||
//CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id); | |||
// TODO(Rinne): implement it. | |||
return (null, null); | |||
} | |||
/// <summary> | |||
/// Creates a saveable using the _serialize_to_tensor method. | |||
/// </summary> | |||
/// <param name="saveable_factories"></param> | |||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
{ | |||
string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||
suffix = suffix ?? ""; | |||
var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
throw new NotImplementedException("The restore under graph mode has not been implemented. " + | |||
"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||
// skip the cache. | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
dict[saveable_name] = saveable; | |||
return (new List<Operation>(), dict); | |||
} | |||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
{ | |||
// TODO(Rinne): implement it. | |||
if(ObjectProto.Attributes is null) | |||
{ | |||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||
} | |||
List<Operation> existing_restore_ops = new(); | |||
HashSet<string> created_compat_names = new(); | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
foreach (var serialized_tensor in ObjectProto.Attributes) | |||
{ | |||
Operation existing_op; | |||
if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) | |||
{ | |||
existing_op = null; | |||
} | |||
else | |||
{ | |||
existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; | |||
} | |||
if(existing_op is not null) | |||
{ | |||
existing_restore_ops.Add(existing_op); | |||
continue; | |||
} | |||
if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) | |||
{ | |||
continue; | |||
} | |||
// TODO(Rinne): deal with cache. | |||
var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); | |||
if(saveable is null) | |||
{ | |||
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||
continue; | |||
} | |||
named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||
} | |||
return (existing_restore_ops, named_saveables); | |||
} | |||
private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||
{ | |||
var expected_factory_name = serialized_tensor.Name; | |||
var factory_input_name = serialized_tensor.CheckpointKey; | |||
if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) | |||
{ | |||
foreach(var item in saveable_factories) | |||
{ | |||
var factory_name = item.Key; | |||
var factory = item.Value; | |||
if (expected_factory_name.StartsWith(factory_name)) | |||
{ | |||
if(matched_factory is not null) | |||
{ | |||
throw new ValueError($"Forward compatibility load error: Unable to load " + | |||
"checkpoint saved in future version of TensorFlow. " + | |||
"Please update your version of TensorFlow to the " + | |||
"version in which the checkpoint was saved."); | |||
} | |||
// Only bind the factory when its name prefixes the expected name, | |||
// mirroring the TensorFlow Python implementation. | |||
matched_factory = factory; | |||
factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; | |||
created_compat_names.Add(factory_name); | |||
} | |||
} | |||
} | |||
return matched_factory(factory_input_name); | |||
} | |||
private string _extract_saveable_name(string checkpoint_key) | |||
{ | |||
var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; | |||
return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); | |||
} | |||
/// <summary> | |||
/// Restore the bound Trackable and dependencies (may be deferred). | |||
/// </summary> | |||
private List<Operation> _restore_descendants() | |||
{ | |||
Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||
visit_queue.Enqueue((this, this.Trackable)); | |||
List<Operation> restore_ops = new(); | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
List<CheckpointPosition> positions = new(); | |||
CheckpointPosition current_position = null; | |||
while (visit_queue.Count > 0) | |||
{ | |||
current_position = visit_queue.Dequeue().Item1; | |||
var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore(); | |||
restore_ops.AddRange(new_restore_ops); | |||
foreach(var item in new_tensor_saveables) | |||
{ | |||
tensor_saveables.Add(item.Key, item.Value); | |||
} | |||
positions.AddRange(new_positions); | |||
_queue_children_for_restoration(current_position, visit_queue); | |||
_queue_slot_variables(current_position, visit_queue); | |||
} | |||
restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null)); | |||
return restore_ops; | |||
} | |||
private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) | |||
{ | |||
var trackable = checkpoint_position.Trackable; | |||
foreach(var child in checkpoint_position.ObjectProto.Children) | |||
{ | |||
var child_position = checkpoint_position.create_child_position(child.NodeId); | |||
var local_object = trackable._lookup_dependency(child.LocalName); | |||
var child_proto = child_position.ObjectProto; | |||
if(local_object is null) | |||
{ | |||
if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any()) | |||
{ | |||
trackable.DeferredDependencies.SetDefault(child.LocalName, new List<CheckpointPosition>()).Add(child_position); | |||
} | |||
} | |||
else | |||
{ | |||
if (child_position.bind_project(local_object)) | |||
{ | |||
visit_queue.Enqueue((child_position, local_object)); | |||
} | |||
} | |||
} | |||
} | |||
private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) | |||
{ | |||
var trackable = checkpoint_position.Trackable; | |||
var checkpoint = checkpoint_position.Checkpoint; | |||
if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions)) | |||
{ | |||
checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id); | |||
foreach (var deferred_slot_restoration in positions) | |||
{ | |||
var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position( | |||
trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId, | |||
deferred_slot_restoration.SlotName | |||
); | |||
if(slot_variable_position is not null) | |||
{ | |||
visit_queue.Enqueue((slot_variable_position, slot_variable)); | |||
} | |||
} | |||
} | |||
if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations)) | |||
{ | |||
checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id); | |||
foreach (var slot_restoration in restorations) | |||
{ | |||
if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object)) | |||
{ | |||
throw new NotImplementedException(); | |||
// TODO(Rinne): implement it. | |||
} | |||
else | |||
{ | |||
Debug.Assert(trackable is BaseResourceVariable); | |||
Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List<DeferredSlotVariableRestoration>()) | |||
.Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName)); | |||
} | |||
} | |||
} | |||
} | |||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
{ | |||
var trackable = this.Trackable; | |||
trackable._maybe_initialize_trackable(); | |||
if(_checkpoint.RestoreUid > trackable.UpdateUid) | |||
{ | |||
var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables(); | |||
trackable.UpdateUid = _checkpoint.RestoreUid; | |||
return (restore_ops, tensor_saveables, positions, registered_savers); | |||
} | |||
else | |||
{ | |||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
new List<CheckpointPosition>(), null); | |||
} | |||
} | |||
} | |||
public record class DeferredSlotVariableRestoration( | |||
BaseResourceVariable OriginalVariable, | |||
int SlotVariableId, | |||
string SlotName | |||
); |
@@ -10,7 +10,7 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow.Eager | |||
{ | |||
internal class execute | |||
internal static class execute | |||
{ | |||
public static (DataType[], Tensor[]) convert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||
{ | |||
@@ -27,5 +27,9 @@ namespace Tensorflow.Eager | |||
return tensors; | |||
} | |||
public static bool must_record_gradient() | |||
{ | |||
// Gradient recording for eager fallbacks is not supported yet, so this | |||
// always reports false. | |||
return false; | |||
} | |||
} | |||
} |
@@ -13,8 +13,8 @@ namespace Tensorflow.Functions | |||
/// </summary> | |||
public class ConcreteFunction: Trackable | |||
{ | |||
FuncGraph func_graph; | |||
ForwardBackwardCall forward_backward; | |||
internal FuncGraph func_graph; | |||
internal ForwardBackwardCall forward_backward; | |||
public Tensor[] Inputs => func_graph.Inputs; | |||
public Tensor[] CapturedInputs => func_graph.external_captures; | |||
@@ -23,6 +23,8 @@ namespace Tensorflow.Functions | |||
public Tensor[] Outputs; | |||
public Type ReturnType; | |||
public TensorSpec[] OutputStructure; | |||
public IEnumerable<string> ArgKeywords { get; set; } | |||
public long NumPositionArgs { get; set; } | |||
public ConcreteFunction(string name) | |||
{ | |||
@@ -163,6 +165,15 @@ namespace Tensorflow.Functions | |||
return flat_outputs; | |||
} | |||
public void AddToGraph(Graph? g = null) | |||
{ | |||
if(!tf.Context.executing_eagerly() && g is null) | |||
{ | |||
g = ops.get_default_graph(); | |||
} | |||
// TODO(Rinne): complete it with `_delayed_rewrite_functions`. | |||
} | |||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
{ | |||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
@@ -16,8 +16,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.IO | |||
{ | |||
@@ -63,5 +65,15 @@ namespace Tensorflow.IO | |||
dirs.AddRange(Directory.GetFiles(dir)); | |||
return dirs.ToArray(); | |||
} | |||
/// <summary>
/// Join path components, analogous to `tf.io.gfile.join` in Python.
/// URL-style paths (e.g. "gs://...") are not supported yet.
/// </summary>
public string join(params string[] paths) | |||
{ | |||
Debug.Assert(paths.Length >= 1); | |||
if (paths[0].Substring(1).Contains("://")) | |||
{ | |||
throw new NotImplementedException("The combination of urls has not been implemented."); | |||
} | |||
return Path.Combine(paths); | |||
} | |||
} | |||
} |
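A small sketch of the new `join` helper (assuming it is exposed as `tf.io.gfile.join`, mirroring the Python API):

    // Local components are combined with Path.Combine; URL-style paths throw for now.
    var variables_path = tf.io.gfile.join("saved_model", "variables", "variables");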
@@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
var axis = serializer.Deserialize(reader, typeof(long[])); | |||
int[]? axis; | |||
if(reader.ValueType == typeof(long)) | |||
{ | |||
axis = new int[1]; | |||
axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | |||
} | |||
else | |||
{ | |||
axis = serializer.Deserialize(reader, typeof(int[])) as int[]; | |||
} | |||
if (axis is null) | |||
{ | |||
throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||
@@ -0,0 +1,36 @@ | |||
using Newtonsoft.Json.Linq; | |||
using Newtonsoft.Json; | |||
namespace Tensorflow.Keras.Common | |||
{ | |||
public class CustomizedDTypeJsonConverter : JsonConverter | |||
{ | |||
public override bool CanConvert(Type objectType) | |||
{ | |||
return objectType == typeof(TF_DataType); | |||
} | |||
public override bool CanRead => true; | |||
public override bool CanWrite => true; | |||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
{ | |||
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)); | |||
token.WriteTo(writer); | |||
} | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
if (reader.ValueType == typeof(string)) | |||
{ | |||
var str = (string)serializer.Deserialize(reader, typeof(string)); | |||
return dtypes.tf_dtype_from_name(str); | |||
} | |||
else | |||
{ | |||
return (TF_DataType)serializer.Deserialize(reader, typeof(int)); | |||
} | |||
} | |||
} | |||
} |
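A round-trip sketch with Newtonsoft.Json; since `TF_DataType` gains a `[JsonConverter(typeof(CustomizedDTypeJsonConverter))]` attribute later in this diff, the converter applies automatically. The exact name string depends on `dtypes.as_numpy_name`:

    string json = JsonConvert.SerializeObject(TF_DataType.TF_FLOAT);   // e.g. "float32"
    var dtype = JsonConvert.DeserializeObject<TF_DataType>(json);      // TF_DataType.TF_FLOAT
    var from_int = JsonConvert.DeserializeObject<TF_DataType>("1");    // numeric form also accepted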
@@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common | |||
{ | |||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||
} | |||
if(values.Length != 3) | |||
if(values.Length == 1) | |||
{ | |||
var array = values[0] as JArray; | |||
if(array is null) | |||
{ | |||
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
} | |||
values = array.ToObject<object[]>(); | |||
} | |||
if (values.Length < 3) | |||
{ | |||
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
} | |||
@@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common | |||
{ | |||
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | |||
} | |||
if (values[1] is not int) | |||
int nodeIndex; | |||
int tensorIndex; | |||
if (values[1] is long) | |||
{ | |||
nodeIndex = (int)(long)values[1]; | |||
} | |||
else if (values[1] is int) | |||
{ | |||
nodeIndex = (int)values[1]; | |||
} | |||
else | |||
{ | |||
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | |||
} | |||
if (values[2] is not int) | |||
if (values[2] is long) | |||
{ | |||
tensorIndex = (int)(long)values[2]; | |||
} | |||
else if (values[2] is int) | |||
{ | |||
tensorIndex = (int)values[2]; | |||
} | |||
else | |||
{ | |||
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | |||
} | |||
return new NodeConfig() | |||
{ | |||
Name = values[0] as string, | |||
NodeIndex = (int)values[1], | |||
TensorIndex = (int)values[2] | |||
NodeIndex = nodeIndex, | |||
TensorIndex = tensorIndex | |||
}; | |||
} | |||
} | |||
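For reference, a hedged sketch of the input this converter now tolerates (the layer name is hypothetical, and it assumes the converter is registered on `NodeConfig`); JSON integers arrive as `long` under Newtonsoft.Json, hence the long/int branches above:

    // Accepts ["dense", 0, 0] as well as the nested form [["dense", 0, 0]].
    var node = JsonConvert.DeserializeObject<NodeConfig>("[\"dense\", 0, 0]");
    // node.Name == "dense", node.NodeIndex == 0, node.TensorIndex == 0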
@@ -51,10 +51,28 @@ namespace Tensorflow.Keras.Common | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
if(dims is null) | |||
long?[] dims; | |||
try | |||
{ | |||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
} | |||
catch (JsonSerializationException ex) | |||
{ | |||
if (reader.Value.Equals("class_name")) | |||
{ | |||
reader.Read(); | |||
reader.Read(); | |||
reader.Read(); | |||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
} | |||
else | |||
{ | |||
throw ex; | |||
} | |||
} | |||
if (dims is null) | |||
{ | |||
return null; | |||
} | |||
long[] convertedDims = new long[dims.Length]; | |||
for(int i = 0; i < dims.Length; i++) | |||
@@ -19,6 +19,7 @@ namespace Tensorflow.Keras | |||
List<IVariableV1> TrainableVariables { get; } | |||
List<IVariableV1> TrainableWeights { get; } | |||
List<IVariableV1> NonTrainableWeights { get; } | |||
List<IVariableV1> Weights { get; } | |||
Shape OutputShape { get; } | |||
Shape BatchInputShape { get; } | |||
TensorShapeConfig BuildInputShape { get; } | |||
@@ -1,8 +1,11 @@ | |||
using Newtonsoft.Json; | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
namespace Tensorflow.Keras.Saving | |||
{ | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
namespace Tensorflow.ModelSaving | |||
{ | |||
@@ -71,6 +71,7 @@ namespace Tensorflow | |||
public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | |||
public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | |||
public List<IVariableV1> Weights => throw new NotImplementedException(); | |||
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | |||
public Shape OutputShape => throw new NotImplementedException(); | |||
@@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations | |||
/// | |||
/// Callers must ensure all the named tensors are indeed stored in the checkpoint. | |||
/// </remarks> | |||
public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||
{ | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
try | |||
{ | |||
Dictionary<string, object> attrs = new(); | |||
attrs["dtypes"] = dtypes; | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
"RestoreV2", name, prefix, tensor_names, shape_and_slices | |||
) | |||
{ attrs = attrs }); | |||
return result; | |||
} | |||
catch (Exception) | |||
{ | |||
try | |||
{ | |||
return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); | |||
} | |||
catch (Exception) | |||
{ | |||
// Fall through to the graph-mode path below if the eager fallback also fails. | |||
} | |||
} | |||
} | |||
var dict = new Dictionary<string, object>(); | |||
dict["prefix"] = prefix; | |||
dict["tensor_names"] = tensor_names; | |||
@@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations | |||
return (tensors); | |||
} | |||
public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) | |||
{ | |||
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||
object[] attrs = new object[] { "dtypes", dtypes }; | |||
Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; | |||
var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
{ | |||
// TODO(Rinne): record the gradient. | |||
} | |||
return result; | |||
} | |||
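A hedged call sketch for the revised signature (the checkpoint key is hypothetical; `shape_and_slices` entries are empty strings for full, unsliced tensors):

    var tensors = gen_ops.restore_v2(
        prefix: constant_op.constant("ckpt/model"),
        tensor_names: new[] { "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE" },
        shape_and_slices: new[] { "" },
        dtypes: new[] { TF_DataType.TF_FLOAT });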
/// <summary> | |||
/// Reverses specific dimensions of a tensor. | |||
/// </summary> | |||
@@ -62,6 +62,7 @@ namespace Tensorflow | |||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
{ | |||
// Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. | |||
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||
return _op.outputs; | |||
@@ -17,8 +17,8 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.ModelSaving; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using Tensorflow.Variables; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
@@ -1,9 +1,13 @@ | |||
namespace Tensorflow | |||
using Newtonsoft.Json; | |||
using Tensorflow.Keras.Common; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | |||
/// The enum values here are identical to corresponding values in types.proto. | |||
/// </summary> | |||
[JsonConverter(typeof(CustomizedDTypeJsonConverter))] | |||
public enum TF_DataType | |||
{ | |||
DtInvalid = 0, | |||
@@ -159,7 +159,10 @@ namespace Tensorflow | |||
"uint32" => TF_DataType.TF_UINT32, | |||
"int64" => TF_DataType.TF_INT64, | |||
"uint64" => TF_DataType.TF_UINT64, | |||
"float16" => TF_DataType.TF_BFLOAT16, | |||
"float32" => TF_DataType.TF_FLOAT, | |||
"single" => TF_DataType.TF_FLOAT, | |||
"float64" => TF_DataType.TF_DOUBLE, | |||
"double" => TF_DataType.TF_DOUBLE, | |||
"complex" => TF_DataType.TF_COMPLEX128, | |||
"string" => TF_DataType.TF_STRING, | |||
@@ -39,6 +39,24 @@ namespace Tensorflow | |||
_op = value; | |||
} | |||
} | |||
public BaseResourceVariable variable | |||
{ | |||
get | |||
{ | |||
if (_op.TryGet<BaseResourceVariable>(out var v)) | |||
{ | |||
return v; | |||
} | |||
else | |||
{ | |||
throw new TypeError("The _op is not a variable."); | |||
} | |||
} | |||
set | |||
{ | |||
_op = value; | |||
} | |||
} | |||
public SaveSpec[] specs; | |||
public string name; | |||
public string device; | |||
@@ -0,0 +1,23 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public record class LoadOptions | |||
{ | |||
public bool allow_partial_checkpoint; | |||
public string experimental_io_device; | |||
public bool experimental_skip_checkpoint; | |||
public VariablePolicy experimental_variable_policy; | |||
public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null, | |||
bool experimental_skip_checkpoint = false, string experimental_variable_policy = null) | |||
{ | |||
this.allow_partial_checkpoint = allow_partial_checkpoint; | |||
this.experimental_io_device = experimental_io_device; | |||
this.experimental_skip_checkpoint = experimental_skip_checkpoint; | |||
this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy); | |||
} | |||
} | |||
} |
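A minimal construction sketch; the string policy is parsed through `VariablePolicy.from_obj`, and the device string is hypothetical:

    var options = new LoadOptions(
        allow_partial_checkpoint: false,
        experimental_io_device: "/job:localhost",
        experimental_variable_policy: "save_variable_devices");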
@@ -1,4 +1,5 @@ | |||
using Tensorflow.Train; | |||
using System; | |||
using Tensorflow.Train; | |||
namespace Tensorflow; | |||
@@ -14,4 +15,10 @@ public class RevivedTypes | |||
// TODO: complete the implementation. | |||
return null; | |||
} | |||
public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto) | |||
{ | |||
// TODO: complete the implementation. | |||
return null; | |||
} | |||
} |
@@ -2,7 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.ModelSaving | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Options for saving to SavedModel. | |||
@@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving | |||
public bool save_variable_devices() | |||
{ | |||
return this != VariablePolicy.None; | |||
return this != None; | |||
} | |||
/// <summary> | |||
@@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving | |||
/// <returns></returns> | |||
public static VariablePolicy from_obj(object obj) | |||
{ | |||
if (obj is null) return VariablePolicy.None; | |||
if (obj is null) return None; | |||
if (obj is VariablePolicy) return (VariablePolicy)obj; | |||
var key = obj.ToString().ToLower(); | |||
return key switch | |||
{ | |||
null => VariablePolicy.None, | |||
"save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||
"expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||
null => None, | |||
"save_variable_devices" => SAVE_VARIABLE_DEVICES, | |||
"expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, | |||
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | |||
}; | |||
} |
@@ -5,7 +5,6 @@ using System.Linq; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Functions; | |||
using Tensorflow.ModelSaving; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using pbc = global::Google.Protobuf.Collections; | |||
@@ -0,0 +1,22 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Functions; | |||
namespace Tensorflow.Training.Saving.SavedModel | |||
{ | |||
/// <summary> | |||
/// A class wraps a concrete function to handle different distributed contexts. | |||
/// </summary> | |||
internal class WrapperFunction: ConcreteFunction | |||
{ | |||
public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) | |||
{ | |||
this.forward_backward = concrete_function.forward_backward; | |||
this.Outputs = concrete_function.Outputs; | |||
this.ReturnType = concrete_function.ReturnType; | |||
this.OutputStructure = concrete_function.OutputStructure; | |||
this.ArgKeywords = concrete_function.ArgKeywords; | |||
} | |||
} | |||
} |
@@ -0,0 +1,36 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Training.Saving.SavedModel | |||
{ | |||
public static class function_deserialization | |||
{ | |||
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | |||
IDictionary<string, ConcreteFunction> concrete_functions) | |||
{ | |||
var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName]; | |||
concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); | |||
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | |||
var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | |||
// TODO: attach `function_spec` to the concrete function once function specs are supported. | |||
concrete_function.AddToGraph(); | |||
return concrete_function; | |||
} | |||
private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto) | |||
{ | |||
// TODO(Rinne): revise the implementation. | |||
return new FunctionSpec() | |||
{ | |||
Fullargspec = function_spec_proto.Fullargspec, | |||
IsMethod = function_spec_proto.IsMethod, | |||
InputSignature = function_spec_proto.InputSignature, | |||
JitCompile = function_spec_proto.JitCompile | |||
}; | |||
} | |||
} | |||
} |
@@ -0,0 +1,641 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Net.Sockets; | |||
using System.Text; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using static Tensorflow.Binding; | |||
using System.Runtime.CompilerServices; | |||
using Tensorflow.Variables; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Helper class to load an object-based SavedModel. | |||
/// </summary> | |||
public partial class Loader | |||
{ | |||
private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def; | |||
private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes; | |||
private SavedObjectGraph _proto; | |||
private string _export_dir; | |||
private CheckpointOptions _checkpoint_options; | |||
private LoadOptions _save_options; | |||
private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters; | |||
private Dictionary<string, int>? _node_path_to_id; | |||
private List<int>? _filtered_nodes; | |||
private List<int> _ordered_node_ids; | |||
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
private List<Trackable> _nodes; | |||
private Dictionary<int, Action<object, object, object>> _node_setters; | |||
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | |||
CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | |||
{ | |||
var meta_graph = saved_model_proto.MetaGraphs[0]; | |||
_asset_file_def = meta_graph.AssetFileDef; | |||
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | |||
_proto = object_graph_proto; | |||
_export_dir = export_dir; | |||
// TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||
_checkpoint_options = ckpt_options; | |||
_save_options = save_options; | |||
// TODO: `this._pretty_printer` | |||
_node_filters = filters; | |||
_node_path_to_id = _convert_node_paths_to_ints(); | |||
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); | |||
foreach(var filter in filters) | |||
{ | |||
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; | |||
} | |||
_filtered_nodes = _retrieve_all_filtered_nodes(); | |||
_ordered_node_ids = _generate_ordered_node_ids(); | |||
_load_all(); | |||
if (!save_options.experimental_skip_checkpoint) | |||
{ | |||
_restore_checkpoint(); | |||
} | |||
foreach(var node in _nodes) | |||
{ | |||
// skip the process of `CapturableResource`. | |||
} | |||
} | |||
/// <summary> | |||
/// Maps all string node paths in node_filters to the int node ids. | |||
/// </summary> | |||
/// <returns></returns> | |||
private Dictionary<string, int>? _convert_node_paths_to_ints() | |||
{ | |||
if( _node_filters is null) | |||
{ | |||
return null; | |||
} | |||
Dictionary<string, int> path_to_int = new(); | |||
foreach(var node_id in _node_filters.Keys) | |||
{ | |||
int int_node_id; | |||
var node_path = node_id.Split('.'); | |||
if (node_path[0] != "root") | |||
{ | |||
throw new ValueError($"When passing string identifiers to node_filters, the first name" + | |||
$" must be root. Received {node_path[0]}."); | |||
} | |||
int_node_id = 0; | |||
for(int i = 0; i < node_path.Length - 1; i++) | |||
{ | |||
var name = node_path[i + 1]; | |||
int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1))); | |||
} | |||
path_to_int[node_id] = int_node_id; | |||
} | |||
return path_to_int; | |||
} | |||
private int _find_node_child(int node_id, string child_name, string path) | |||
{ | |||
foreach(var refer in _proto.Nodes[node_id].Children) | |||
{ | |||
if(refer.LocalName == child_name) | |||
{ | |||
return refer.NodeId; | |||
} | |||
} | |||
throw new ValueError($"Unable to find node {path}."); | |||
} | |||
private List<int>? _retrieve_all_filtered_nodes() | |||
{ | |||
if(_node_filters is null) | |||
{ | |||
return null; | |||
} | |||
HashSet<int> all_filtered_nodes = new(); | |||
Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys); | |||
while(nodes_to_visit.Count > 0) | |||
{ | |||
var node_path = nodes_to_visit.Dequeue(); | |||
var node_id = _node_path_to_id[node_path]; | |||
if (all_filtered_nodes.Contains(node_id)) | |||
{ | |||
continue; | |||
} | |||
all_filtered_nodes.Add(node_id); | |||
Trackable node = null; | |||
Action<object, object, object> setter = null; | |||
if(_loaded_nodes.TryGetValue(node_id, out var res)) | |||
{ | |||
(node, setter) = res; | |||
} | |||
if(node is not null) | |||
{ | |||
node._maybe_initialize_trackable(); | |||
} | |||
foreach(var refer in _proto.Nodes[node_id].Children) | |||
{ | |||
Trackable children_object = null; | |||
if(_loaded_nodes.TryGetValue(refer.NodeId, out var result)) | |||
{ | |||
children_object = result.Item1; | |||
} | |||
// See if node already tracks the child reference, in which case add the child to the loaded_nodes dict. | |||
if(children_object is null && node is not null) | |||
{ | |||
children_object = node._lookup_dependency(refer.LocalName); | |||
if(children_object is TrackableDataStructure) | |||
{ | |||
// TODO: set setter as lambda. | |||
_loaded_nodes[refer.NodeId] = (children_object, setter); | |||
} | |||
} | |||
string child_path = $"{node_path}.{refer.LocalName}"; | |||
_node_path_to_id[child_path] = refer.NodeId; | |||
nodes_to_visit.Enqueue(child_path); | |||
} | |||
} | |||
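// If the root node is included, the whole object graph will be loaded anyway, so no filtering is needed. | |||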
if (all_filtered_nodes.Contains(0)) | |||
{ | |||
return null; | |||
} | |||
return all_filtered_nodes.ToList(); | |||
} | |||
/// <summary> | |||
/// Orders the node ids so that dependencies appear first. | |||
/// </summary> | |||
/// <returns></returns> | |||
private List<int> _generate_ordered_node_ids() | |||
{ | |||
List<int> unordered_ids; | |||
if(_filtered_nodes is null) | |||
{ | |||
unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList(); | |||
} | |||
else | |||
{ | |||
unordered_ids = new List<int>(_filtered_nodes); | |||
} | |||
Dictionary<int, List<int>> dependency_map = new(); | |||
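// dependency_map[node_id] lists the ids that must be created before `node_id` itself. | |||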
foreach(var node_id in unordered_ids) | |||
{ | |||
var deps = dependency_map.SetDefault(node_id, new List<int>()); | |||
if (_loaded_nodes.ContainsKey(node_id)) | |||
{ | |||
continue; | |||
} | |||
var proto = _proto.Nodes[node_id]; | |||
foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) | |||
{ | |||
deps.Add(dep); | |||
if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) | |||
{ | |||
// TODO: add info with `_pretty_printer`. | |||
throw new ValueError($"Unable to partially load SavedModel since the specified filter " + | |||
$"does not include all required objects for loading (e.g. " + | |||
$"variables used in functions or deserialization dependencies). " + | |||
$"Please include this path in the filter: {dep}"); | |||
} | |||
} | |||
int? prev_slot = null; | |||
foreach(var slot_variable_proto in proto.SlotVariables) | |||
{ | |||
var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; | |||
// The optimizer and original variable must be created before the slot | |||
// variable, since the slot variable is generated using the Optimizer's | |||
// add_slot API. | |||
var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>()); | |||
slot_deps.Add(node_id); | |||
slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); | |||
if(prev_slot is not null) | |||
{ | |||
slot_deps.Add(prev_slot.Value); | |||
} | |||
prev_slot = slot_variable_node_id; | |||
} | |||
} | |||
try | |||
{ | |||
return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>)); | |||
} | |||
catch (TrackableUtils.CyclicDependencyError) | |||
{ | |||
throw new ValueError("Encountered a cycle in the deserialization dependencies " + | |||
"in the SavedModel. This is extremely unexpected, please " + | |||
"file a bug and make sure you are not manually modifying the SavedModel."); | |||
} | |||
} | |||
/// <summary> | |||
/// Returns a dictionary of all dependencies of an object. | |||
/// </summary> | |||
/// <param name="proto"></param> | |||
/// <returns></returns> | |||
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto) | |||
{ | |||
Dictionary<Maybe<string, int>, int> dependencies = new(); | |||
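// Keys are either child local names (string) or bound-input ids (int); values are the node ids they resolve to. | |||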
foreach(var refer in proto.Dependencies) | |||
{ | |||
dependencies[refer.LocalName] = refer.NodeId; | |||
} | |||
if(proto.KindCase == SavedObject.KindOneofCase.Function) | |||
{ | |||
var concrete_functions = proto.Function.ConcreteFunctions; | |||
foreach(var fn_name in concrete_functions) | |||
{ | |||
foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) | |||
{ | |||
dependencies[bound_input] = bound_input; | |||
} | |||
} | |||
} | |||
else if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction) | |||
{ | |||
var fn_name = proto.BareConcreteFunction.ConcreteFunctionName; | |||
foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) | |||
{ | |||
dependencies[bound_input] = bound_input; | |||
} | |||
} | |||
else if(proto.KindCase == SavedObject.KindOneofCase.Resource) | |||
{ | |||
foreach(var child in proto.Children) | |||
{ | |||
if(child.LocalName == "_create_resource") | |||
{ | |||
dependencies["_create_resource"] = child.NodeId; | |||
} | |||
} | |||
} | |||
return dependencies; | |||
} | |||
/// <summary> | |||
/// Loads all nodes and functions from the SavedModel and their edges. | |||
/// </summary> | |||
private void _load_all() | |||
{ | |||
_load_nodes(); | |||
_load_edges(); | |||
_setup_remaining_functions(); | |||
_load_checkpoint_save_and_restore_functions(); | |||
} | |||
/// <summary> | |||
/// Restores the checkpoint-related save/restore functions to all nodes. | |||
/// </summary> | |||
private void _load_checkpoint_save_and_restore_functions() | |||
{ | |||
foreach(var (node_id, proto) in _iter_all_nodes()) | |||
{ | |||
var node = get(node_id); | |||
if(node is null) | |||
{ | |||
// Skip it for now: the restoration of `Function` and `ConcreteFunction` is not yet supported. | |||
continue; | |||
} | |||
if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
{ | |||
// Restore Trackable serialize- and restore-from-tensor functions. | |||
Debug.Assert(proto.SaveableObjects.Count == 1); | |||
var saveable_object_proto = proto.SaveableObjects.Values.First(); | |||
var save_fn_id = saveable_object_proto.SaveFunction; | |||
var restore_fn_id = saveable_object_proto.RestoreFunction; | |||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
else | |||
{ | |||
// Restore legacy SaveableObject functions. | |||
Dictionary<string, (Trackable, Trackable)> saveable_fn_by_name = new(); | |||
foreach(var item in proto.SaveableObjects) | |||
{ | |||
var name = item.Key; | |||
var saveable_object_proto = item.Value; | |||
var save_fn_id = saveable_object_proto.SaveFunction; | |||
var restore_fn_id = saveable_object_proto.RestoreFunction; | |||
saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); | |||
} | |||
node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// Load all saved objects. | |||
/// </summary> | |||
private void _load_nodes() | |||
{ | |||
// `nodes` maps from node ids to recreated objects | |||
// `node_setters` maps from node ids to setter functions | |||
// (same signature as setattr) for setting children. | |||
var (nodes, node_setters) = _initialize_loaded_nodes(); | |||
Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)> | |||
slot_variable_node_ids = new(); | |||
foreach(var (node_id, proto) in _iter_all_nodes()) | |||
{ | |||
foreach(var slot_variable_proto in proto.SlotVariables) | |||
{ | |||
var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; | |||
slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto); | |||
} | |||
} | |||
// Re-create everything. | |||
foreach (var (node_id, proto) in _iter_all_nodes()) | |||
{ | |||
if (nodes.ContainsKey(node_id)) | |||
{ | |||
continue; | |||
} | |||
else if (slot_variable_node_ids.ContainsKey(node_id)) | |||
{ | |||
// Use the public Optimizer interface when creating slot variables. | |||
var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; | |||
var optimizer_object = nodes[optimizer_node_id]; | |||
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; | |||
// TODO: implement it. | |||
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + | |||
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
} | |||
else | |||
{ | |||
// Skip `Function` and `ConcreteFunction` nodes; they are not recreated yet. | |||
if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) | |||
{ | |||
nodes[node_id] = null; | |||
node_setters[node_id] = null; | |||
continue; | |||
} | |||
var (node, setter) = _recreate(proto, node_id, nodes); | |||
nodes[node_id] = node; | |||
node_setters[node_id] = setter; | |||
} | |||
} | |||
if (!nodes.ContainsKey(0)) | |||
{ | |||
nodes[0] = _recreate_base_user_object().Item1; | |||
} | |||
_nodes = new List<Trackable>(); | |||
for(int i = 0; i < _proto.Nodes.Count; i++) | |||
{ | |||
_nodes.Add(nodes[i]); | |||
} | |||
_node_setters = node_setters; | |||
} | |||
/// <summary> | |||
/// Load state from checkpoint into the deserialized objects. | |||
/// </summary> | |||
private void _restore_checkpoint() | |||
{ | |||
var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
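// By SavedModel convention this resolves to the "<export_dir>/variables/variables" checkpoint prefix. | |||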
var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
tf.device("CPU"); | |||
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
LoadStatus load_status; | |||
if (_save_options.allow_partial_checkpoint) | |||
{ | |||
load_status = saver.restore(variables_path, _checkpoint_options).expect_partial(); | |||
load_status.assert_nontrivial_match(); | |||
} | |||
else | |||
{ | |||
load_status = saver.restore(variables_path, _checkpoint_options); | |||
load_status.assert_existing_objects_matched(); | |||
} | |||
var ckpt = (load_status as CheckpointLoadStatus).Checkpoint; | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
throw new NotImplementedException("The checkpoint restore has not supported graph mode. " + | |||
"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
/// <summary> | |||
/// Adds edges from objects to other objects and functions. | |||
/// </summary> | |||
private void _load_edges() | |||
{ | |||
foreach(var (node_id, object_proto) in _iter_all_nodes()) | |||
{ | |||
_add_object_graph_edges(object_proto, node_id); | |||
} | |||
// If the root object was not loaded, create edges from the root for checkpointing. | |||
if(_filtered_nodes is not null && !_filtered_nodes.Contains(0)) | |||
{ | |||
var root = get(0); | |||
foreach(var node_path in _node_filters.Keys) | |||
{ | |||
var loaded_node = _nodes[_node_path_to_id[node_path]]; | |||
var path = node_path.Split('.'); | |||
var current_node = root; | |||
foreach(var name in path.Skip(1).Take(path.Length - 2)) | |||
{ | |||
// `hasattr` and `setattr` are used here | |||
throw new NotImplementedException(); | |||
} | |||
// `hasattr` and `setattr` are used here | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} | |||
private void _setup_remaining_functions() | |||
{ | |||
// TODO: implement it with concrete functions. | |||
} | |||
public Trackable get(int node_id) | |||
{ | |||
return _nodes[node_id]; | |||
} | |||
public Trackable get(string node_id) | |||
{ | |||
return get(_node_path_to_id[node_id]); | |||
} | |||
/// <summary> | |||
/// Adds edges from an object to its children. | |||
/// </summary> | |||
/// <param name="proto"></param> | |||
/// <param name="node_id"></param> | |||
private void _add_object_graph_edges(SavedObject proto, int node_id) | |||
{ | |||
var obj = _nodes[node_id]; | |||
var setter = _node_setters[node_id]; | |||
foreach(var refer in proto.Children) | |||
{ | |||
if(obj is null) | |||
{ | |||
// Skip it for now: the restoration of `Function` and `ConcreteFunction` is not yet supported. | |||
continue; | |||
} | |||
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | |||
// skip the process of "__call__" | |||
} | |||
} | |||
private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes() | |||
{ | |||
Dictionary<int, Trackable> nodes = new(); | |||
Dictionary<int, Action<object, object, object>> node_setters = new(); | |||
foreach(var item in _loaded_nodes) | |||
{ | |||
var node_id = item.Key; | |||
var (node, setter) = item.Value; | |||
nodes[node_id] = node; | |||
node_setters[node_id] = setter; | |||
} | |||
return (nodes, node_setters); | |||
} | |||
private IEnumerable<(int, SavedObject)> _iter_all_nodes() | |||
{ | |||
foreach(var node_id in _ordered_node_ids) | |||
{ | |||
yield return (node_id, _proto.Nodes[node_id]); | |||
} | |||
} | |||
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | |||
{ | |||
// skip the registered classes. | |||
Dictionary<Maybe<string, int>, Trackable> dependencies = new(); | |||
foreach(var item in _get_node_dependencies(proto)) | |||
{ | |||
dependencies[item.Key] = nodes[item.Value]; | |||
} | |||
return _recreate_default(proto, node_id, dependencies); | |||
} | |||
/// <summary> | |||
/// Creates a trackable object from a SavedObject protocol buffer. | |||
/// </summary> | |||
/// <param name="proto"></param> | |||
/// <param name="node_id"></param> | |||
/// <param name="dependencies"></param> | |||
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies) | |||
{ | |||
return proto.KindCase switch | |||
{ | |||
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), | |||
SavedObject.KindOneofCase.Function => throw new NotImplementedException(), | |||
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | |||
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | |||
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | |||
_ => throw new NotImplementedException($"Unexpected SavedObject kind: {proto.KindCase}.") | |||
}; | |||
} | |||
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id) | |||
{ | |||
// The proto identifier check is skipped because the corresponding property is not available. | |||
var looked_up = RevivedTypes.deserialize(proto); | |||
if(looked_up is null) | |||
{ | |||
return _recreate_base_user_object(proto, node_id); | |||
} | |||
return (looked_up.Item1, looked_up.Item2); | |||
} | |||
private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) | |||
{ | |||
return (new _UserObject(), setattr); | |||
} | |||
private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto) | |||
{ | |||
string name = proto.Name; | |||
string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>"; | |||
// TODO(Rinne): `validate_synchronization_aggregation_trainable` | |||
var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable( | |||
proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name); | |||
var saved_device = proto.Device; | |||
var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device); | |||
if (load_with_device) | |||
{ | |||
tf.device(saved_device); | |||
return (new UninitializedVariable( | |||
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
dtype: (TF_DataType)proto.Dtype, | |||
name: name, | |||
trainable: trainable, | |||
aggregation: aggregation | |||
), setattr); | |||
} | |||
else | |||
{ | |||
return (new UninitializedVariable( | |||
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
dtype: (TF_DataType)proto.Dtype, | |||
name: name, | |||
trainable: trainable, | |||
aggregation: aggregation | |||
), setattr); | |||
} | |||
} | |||
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
{ | |||
throw new NotImplementedException(); | |||
//var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||
} | |||
// TODO: remove this to a common class. | |||
public static Action<object, object, object> setattr = (x, y, z) => | |||
{ | |||
Debug.Assert(y is string); | |||
var properties = x.GetType().GetProperties(); | |||
foreach(var p in properties) | |||
{ | |||
if((string)y == p.Name) | |||
{ | |||
p.SetValue(x, z); | |||
return; | |||
} | |||
} | |||
// TODO(Rinne): check if the property has been set successfully. | |||
//throw new ValueError($"Cannot find the property {y} of {x}."); | |||
}; | |||
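// Usage sketch: `setattr(obj, "kernel", variable)` assigns to a public property named | |||
// "kernel" via reflection, and silently does nothing when no such property exists. | |||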
public class _UserObject: AutoTrackable | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,122 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public partial class Loader | |||
{ | |||
public static SavedModel parse_saved_model(string export_dir) | |||
{ | |||
var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT); | |||
var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB); | |||
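// Prefer the binary "saved_model.pb"; fall back to the text format "saved_model.pbtxt". | |||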
SavedModel saved_model = new SavedModel(); | |||
if (File.Exists(path_to_pb)) | |||
{ | |||
byte[] file_content; | |||
using(var f = new FileStream(path_to_pb, FileMode.Open, FileAccess.Read)) | |||
{ | |||
file_content = new byte[f.Length]; | |||
Debug.Assert(f.Length <= int.MaxValue); | |||
f.Read(file_content, 0, (int)f.Length); | |||
} | |||
// TODO: change to stream mode. | |||
saved_model.MergeFrom(file_content); | |||
return saved_model; | |||
} | |||
else if (File.Exists(path_to_pbtxt)) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.PathSeparator}" + | |||
$"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}"); | |||
} | |||
} | |||
// TODO: revise the type of `tags` | |||
public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null) | |||
{ | |||
return load_partial(export_dir, null, tags, options)["root"]; | |||
} | |||
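// Usage sketch (the export directory is hypothetical): | |||
// var root = Loader.load(@"./alexnet_savedmodel"); | |||
// `load_partial` instead returns only the requested subtrees; its `filters` map node | |||
// paths such as "root.dense" to pre-created objects and their attribute setters. | |||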
public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null) | |||
{ | |||
if (options is null) | |||
{ | |||
options = new LoadOptions(); | |||
} | |||
if (tags is not null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir); | |||
Trackable root = null; | |||
Loader loader = null; | |||
if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null) | |||
{ | |||
// skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)` | |||
var meta_graph_def = saved_model_proto.MetaGraphs[0]; | |||
if (!BitConverter.IsLittleEndian) | |||
{ | |||
SavedModelUtils.swap_function_tensor_content(meta_graph_def); | |||
} | |||
var object_graph_proto = meta_graph_def.ObjectGraphDef; | |||
var ckpt_options = new CheckpointOptions(options.experimental_io_device); | |||
tf_with(ops.init_scope(), x => | |||
{ | |||
loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); | |||
root = loader.get(0); | |||
// skip the assignment of `graph_debug_info`. | |||
}); | |||
// skip the assignment of `tensorflow_version` | |||
// skip the assignment of `tensorflow_git_version` | |||
// skip the process of `metrics`. | |||
} | |||
else | |||
{ | |||
if(filters is not null && filters.Count > 0) | |||
{ | |||
throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" | |||
+ " version) cannot be loaded with node filters."); | |||
} | |||
tf_with(ops.init_scope(), x => | |||
{ | |||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
}); | |||
} | |||
if(filters != null && filters.Count > 0) | |||
{ | |||
return filters.Keys.ToDictionary(x => x, x => loader.get(x)); | |||
} | |||
else | |||
{ | |||
var res = new Dictionary<string, Trackable>(); | |||
res["root"] = root; | |||
return res; | |||
} | |||
} | |||
public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir) | |||
{ | |||
var saved_model = parse_saved_model(export_dir); | |||
// TODO: implement debug info. | |||
return (saved_model, null); | |||
} | |||
} | |||
} |
@@ -6,7 +6,6 @@ using System.Text; | |||
using Google.Protobuf; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Functions; | |||
using Tensorflow.ModelSaving; | |||
using Tensorflow.Train; | |||
using Tensorflow.Exceptions; | |||
using static Tensorflow.Binding; | |||
@@ -1,7 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.ModelSaving; | |||
namespace Tensorflow.Training.Saving.SavedModel | |||
{ | |||
@@ -68,6 +68,34 @@ namespace Tensorflow | |||
return saveables.ToArray(); | |||
} | |||
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables) | |||
{ | |||
var saveables = new List<MySaveableObject>(); | |||
var seen_ops = new List<Tensor>(); | |||
foreach (var (name, op) in enumerate(names_to_saveables)) | |||
{ | |||
foreach (var converted_saveable_object in saveable_objects_for_op(op, name)) | |||
_add_saveable(saveables, seen_ops, converted_saveable_object); | |||
} | |||
return saveables.ToArray(); | |||
} | |||
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables) | |||
{ | |||
var saveables = new List<MySaveableObject>(); | |||
var seen_ops = new List<BaseResourceVariable>(); | |||
foreach(var item in names_to_saveables.OrderBy(x => x.Key)) | |||
{ | |||
foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) | |||
{ | |||
_add_saveable(saveables, seen_ops, converted_saveable_object); | |||
} | |||
} | |||
return saveables.ToArray(); | |||
} | |||
private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | |||
{ | |||
if (seen_ops.Contains(saveable.op)) | |||
@@ -77,6 +105,15 @@ namespace Tensorflow | |||
seen_ops.Add(saveable.op); | |||
} | |||
private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable) | |||
{ | |||
if (seen_ops.Contains(saveable.variable)) | |||
throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}"); | |||
saveables.Add(saveable); | |||
seen_ops.Add(saveable.variable); | |||
} | |||
/// <summary> | |||
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | |||
/// </summary> | |||
@@ -136,19 +173,20 @@ namespace Tensorflow | |||
{ | |||
full_name = name + "_" + attr; | |||
} | |||
if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||
var op = factory(full_name); | |||
if(op.TryGet<BaseResourceVariable>(out var variable)) | |||
{ | |||
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
{ | |||
yield return op; | |||
yield return v; | |||
} | |||
} | |||
else | |||
{ | |||
var saveable = factory.GetValue<MySaveableObject>(); | |||
foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||
var saveable = op.GetValue<MySaveableObject>(); | |||
foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | |||
{ | |||
yield return op; | |||
yield return v; | |||
} | |||
} | |||
} | |||
@@ -214,20 +252,19 @@ namespace Tensorflow | |||
return names_to_saveables; | |||
} | |||
public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||
{ | |||
// skip the process of type `PythonState` | |||
if (trackable_has_serialize_to_tensor(obj)) | |||
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
{ | |||
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; | |||
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | |||
var tensor_dict = obj.serialize_to_tensors(); | |||
List<SaveSpec> specs = new(); | |||
List<string> local_names = new(); | |||
string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | |||
foreach(var pair in tensor_dict) | |||
foreach (var pair in tensor_dict) | |||
{ | |||
var tensor_name = pair.Key; | |||
var maybe_tensor = pair.Value; | |||
@@ -235,9 +272,9 @@ namespace Tensorflow | |||
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
IDictionary<string, Tensor> internal_dict; | |||
if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
{ | |||
internal_dict= new Dictionary<string, Tensor>(); | |||
internal_dict = new Dictionary<string, Tensor>(); | |||
internal_dict[""] = tensor; | |||
} | |||
else | |||
@@ -245,13 +282,18 @@ namespace Tensorflow | |||
internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
} | |||
foreach(var item in internal_dict) | |||
foreach (var item in internal_dict) | |||
{ | |||
specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
} | |||
} | |||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new(); | |||
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
} | |||
if (trackable_has_serialize_to_tensor(obj)) | |||
{ | |||
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | |||
return res; | |||
} | |||
else | |||
@@ -333,6 +375,28 @@ namespace Tensorflow | |||
return restored_ops; | |||
}; | |||
} | |||
/// <summary> | |||
/// Returns a dict of SaveableObject factories generated from loaded functions. | |||
/// </summary> | |||
/// <param name="saveable_fn_by_name"></param> | |||
/// <param name="temp_session"></param> | |||
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects( | |||
IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) | |||
{ | |||
if (saveable_fn_by_name.Count > 0) | |||
{ | |||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
return res; | |||
} | |||
public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
bool call_with_mapped_captures = false) | |||
{ | |||
return factory(key); | |||
} | |||
} | |||
public class SaveableCompatibilityConverter: Trackable | |||
@@ -20,8 +20,8 @@ using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.ModelSaving; | |||
using Tensorflow.Training; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Train | |||
@@ -41,9 +41,10 @@ namespace Tensorflow.Train | |||
protected IDictionary<string, Trackable> _unconditional_dependency_names; | |||
protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||
protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | |||
protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||
new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
private bool _manual_tracking = true; | |||
private static Trackable _none = new AutoTrackable(); | |||
@@ -71,6 +72,18 @@ namespace Tensorflow.Train | |||
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | |||
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | |||
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | |||
public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | |||
public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||
{ | |||
get | |||
{ | |||
return _self_saveable_object_factories; | |||
} | |||
set | |||
{ | |||
_self_saveable_object_factories = value; | |||
} | |||
} | |||
/// <summary> | |||
/// Restore-on-create for a variable to be saved with this `Checkpointable`. | |||
@@ -136,9 +149,11 @@ namespace Tensorflow.Train | |||
_self_update_uid = -1; | |||
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||
_unconditional_dependency_names = new Dictionary<string, Trackable>(); | |||
_unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>(); | |||
} | |||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, | |||
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
_maybe_initialize_trackable(); | |||
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | |||
@@ -174,10 +189,19 @@ namespace Tensorflow.Train | |||
/// <param name="trackable"></param> | |||
public virtual void _handle_deferred_dependencies(string name, Trackable trackable) | |||
{ | |||
//_maybe_initialize_trackable(); | |||
//trackable._maybe_initialize_trackable(); | |||
// TODO: complete the implementation. | |||
_maybe_initialize_trackable(); | |||
trackable._maybe_initialize_trackable(); | |||
if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) | |||
{ | |||
_unconditional_deferred_dependencies.Remove(name); | |||
foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) | |||
{ | |||
checkpoint_position.restore(trackable); | |||
} | |||
} | |||
// TODO(Rinne): deal with `_self_name_based_restores` | |||
} | |||
public virtual Trackable? _lookup_dependency(string name) | |||
@@ -225,12 +249,19 @@ namespace Tensorflow.Train | |||
return self_tensor_map.Keys.ToList(); | |||
} | |||
public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
{ | |||
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
{ | |||
throw new NotImplementedException(); | |||
//return new TrackableSaveable(this, null, name, null, null); | |||
} | |||
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||
{ | |||
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | |||
throw new NotImplementedException(); | |||
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
res[""] = create_saveable; | |||
return res; | |||
} | |||
else | |||
{ | |||
@@ -259,4 +290,6 @@ namespace Tensorflow.Train | |||
} | |||
public record class TrackableReference(string Name, Trackable Refer); | |||
public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); | |||
} |
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Exceptions; | |||
using Tensorflow.Train; | |||
@@ -20,9 +21,9 @@ public static class TrackableUtils | |||
LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | |||
} | |||
} | |||
private static string _ESCAPE_CHAR = "."; | |||
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
internal static string _ESCAPE_CHAR = "."; | |||
internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | |||
{ | |||
@@ -5,9 +5,9 @@ using Tensorflow.Variables; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
using System.Collections.Generic; | |||
using Tensorflow.ModelSaving; | |||
using System.Diagnostics; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
namespace Tensorflow | |||
{ | |||
@@ -19,7 +19,11 @@ namespace Tensorflow | |||
protected TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
protected string _handle_name; | |||
protected string handle_name => _handle_name; | |||
public string handle_name | |||
{ | |||
get { return _handle_name; } | |||
set { _handle_name = value; } | |||
} | |||
protected string _unique_id; | |||
public string UniqueId => _unique_id; | |||
@@ -289,10 +293,10 @@ namespace Tensorflow | |||
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | |||
} | |||
public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
{ | |||
var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | |||
return res; | |||
} | |||
@@ -238,5 +238,23 @@ namespace Tensorflow | |||
{ | |||
return _graph_element.eval(session); | |||
} | |||
public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable( | |||
VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name) | |||
{ | |||
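// Defaults mirror the Python implementation: aggregation None, synchronization Auto, | |||
// and trainable unless the synchronization is OnRead. | |||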
if(aggregation is null) | |||
{ | |||
aggregation = VariableAggregation.None; | |||
} | |||
if(synchronization is null) | |||
{ | |||
synchronization = VariableSynchronization.Auto; | |||
} | |||
if (trainable is null) | |||
{ | |||
trainable = synchronization != VariableSynchronization.OnRead; | |||
} | |||
return (synchronization.Value, aggregation.Value, trainable.Value); | |||
} | |||
} | |||
} |
@@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine | |||
/// </summary> | |||
/// <param name="config"></param> | |||
/// <returns></returns> | |||
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||
{ | |||
// Layer instances created during the graph reconstruction process. | |||
var created_layers = new Dictionary<string, ILayer>(); | |||
created_layers = created_layers ?? new Dictionary<string, ILayer>(); | |||
var node_index_map = new Dictionary<(string, int), int>(); | |||
var node_count_by_layer = new Dictionary<ILayer, int>(); | |||
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | |||
@@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine | |||
layer = created_layers[layer_name]; | |||
else | |||
{ | |||
layer = layer_data.ClassName switch | |||
{ | |||
"InputLayer" => InputLayer.from_config(layer_data.Config), | |||
"Dense" => Dense.from_config(layer_data.Config), | |||
_ => throw new NotImplementedException("") | |||
}; | |||
layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); | |||
created_layers[layer_name] = layer; | |||
} | |||
@@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine | |||
Inputs = inputs, | |||
Outputs = outputs | |||
}) | |||
{ | |||
Initialize(inputs, outputs, name); | |||
} | |||
internal void Initialize(Tensors inputs, Tensors outputs, string name = null) | |||
{ | |||
_input_layers = new List<ILayer>(); | |||
_output_layers = new List<ILayer>(); | |||
@@ -70,7 +75,14 @@ namespace Tensorflow.Keras.Engine | |||
this.inputs = inputs; | |||
this.outputs = outputs; | |||
built = true; | |||
_buildInputShape = inputs.shape; | |||
if(inputs.Length > 0) | |||
{ | |||
_buildInputShape = inputs.shape; | |||
} | |||
else | |||
{ | |||
_buildInputShape = new Saving.TensorShapeConfig(); | |||
} | |||
if (outputs.Any(x => x.KerasHistory == null)) | |||
base_layer_utils.create_keras_history(outputs); | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine | |||
public virtual Shape ComputeOutputShape(Shape input_shape) | |||
=> throw new NotImplementedException(""); | |||
protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) | |||
{ | |||
List<IVariableV1> res = new(); | |||
var nested_layers = _flatten_layers(false, false); | |||
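// Only immediate sub-layers are visited (non-recursive, excluding the layer itself). | |||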
foreach (var layer in nested_layers) | |||
{ | |||
if (layer is Layer l) | |||
{ | |||
if (include_trainable == true && include_non_trainable == true) | |||
{ | |||
res.AddRange(l.Variables); | |||
} | |||
else if (include_trainable == true && include_non_trainable == false) | |||
{ | |||
res.AddRange(l.TrainableVariables); | |||
} | |||
else if(include_trainable == false && include_non_trainable == true) | |||
{ | |||
res.AddRange(l.NonTrainableVariables); | |||
} | |||
} | |||
} | |||
return res; | |||
} | |||
} | |||
} |
@@ -12,7 +12,7 @@ public abstract partial class Layer | |||
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||
public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
@@ -66,16 +67,74 @@ namespace Tensorflow.Keras.Engine | |||
public bool SupportsMasking { get; set; } | |||
protected List<IVariableV1> _trainable_weights; | |||
public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||
public virtual List<IVariableV1> TrainableVariables => TrainableWeights; | |||
protected List<IVariableV1> _non_trainable_weights; | |||
public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | |||
public List<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||
public List<IVariableV1> Variables => Weights; | |||
public virtual List<IVariableV1> TrainableWeights | |||
{ | |||
get | |||
{ | |||
if (!this.Trainable) | |||
{ | |||
return new List<IVariableV1>(); | |||
} | |||
var children_weights = _gather_children_variables(true); | |||
return children_weights.Concat(_trainable_weights).Distinct().ToList(); | |||
} | |||
} | |||
public virtual List<IVariableV1> NonTrainableWeights | |||
{ | |||
get | |||
{ | |||
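// A frozen layer reports all of its own and its children's weights as non-trainable. | |||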
if (!this.Trainable) | |||
{ | |||
var children_weights = _gather_children_variables(true, true); | |||
return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); | |||
} | |||
else | |||
{ | |||
var children_weights = _gather_children_variables(include_non_trainable: true); | |||
return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); | |||
} | |||
} | |||
} | |||
public virtual List<IVariableV1> Weights | |||
{ | |||
get | |||
{ | |||
return TrainableWeights.Concat(NonTrainableWeights).ToList(); | |||
} | |||
set | |||
{ | |||
if (Weights.Count() != value.Count()) throw new ValueError( | |||
$"You called `set_weights` on layer \"{this.name}\" " + | |||
$"with a weight list of length {len(value)}, but the layer was " + | |||
$"expecting {len(Weights)} weights."); | |||
foreach (var (this_w, v_w) in zip(Weights, value)) | |||
this_w.assign(v_w, read_value: true); | |||
} | |||
} | |||
protected int id; | |||
public int Id => id; | |||
protected string name; | |||
protected string base_name; | |||
public string Name => name; | |||
public string Name | |||
{ | |||
get | |||
{ | |||
return name; | |||
} | |||
set | |||
{ | |||
name = value; | |||
} | |||
} | |||
protected bool computePreviousMask; | |||
protected List<Operation> updates; | |||
@@ -85,10 +144,11 @@ namespace Tensorflow.Keras.Engine | |||
List<INode> inboundNodes; | |||
public List<INode> InboundNodes => inboundNodes; | |||
List<INode> outboundNodes; | |||
public List<INode> OutboundNodes => outboundNodes; | |||
public JObject SerializedAttributes { get; set; } | |||
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
public CallContext CallContext => callContext.Value; | |||
public Tensor[] input | |||
@@ -117,6 +177,11 @@ namespace Tensorflow.Keras.Engine | |||
protected List<ILayer> _self_tracked_trackables; | |||
public Layer(LayerArgs args) | |||
{ | |||
Initialize(args); | |||
} | |||
internal virtual void Initialize(LayerArgs args) | |||
{ | |||
this.args = args; | |||
// A stateful layer is a layer whose updates are run during inference too, | |||
@@ -273,46 +338,9 @@ namespace Tensorflow.Keras.Engine | |||
public int count_params() | |||
{ | |||
if (Trainable) | |||
return layer_utils.count_params(this, weights); | |||
return layer_utils.count_params(this, Weights); | |||
return 0; | |||
} | |||
List<IVariableV1> ILayer.TrainableWeights | |||
{ | |||
get | |||
{ | |||
return _trainable_weights; | |||
} | |||
} | |||
List<IVariableV1> ILayer.NonTrainableWeights | |||
{ | |||
get | |||
{ | |||
return _non_trainable_weights; | |||
} | |||
} | |||
public List<IVariableV1> weights | |||
{ | |||
get | |||
{ | |||
var weights = new List<IVariableV1>(); | |||
weights.AddRange(_trainable_weights); | |||
weights.AddRange(_non_trainable_weights); | |||
return weights; | |||
} | |||
set | |||
{ | |||
if (weights.Count() != value.Count()) throw new ValueError( | |||
$"You called `set_weights` on layer \"{this.name}\"" + | |||
$"with a weight list of length {len(value)}, but the layer was " + | |||
$"expecting {len(weights)} weights."); | |||
foreach (var (this_w, v_w) in zip(weights, value)) | |||
this_w.assign(v_w, read_value: true); | |||
} | |||
} | |||
public List<IVariableV1> Variables => weights; | |||
public virtual IKerasConfig get_config() | |||
=> args; | |||
@@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
using (SharedObjectSavingScope.Enter()) | |||
{ | |||
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
} | |||
} | |||
} | |||
@@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine | |||
IVariableV1 _predict_counter; | |||
bool _base_model_initialized; | |||
bool stop_training; | |||
public bool IsGraphNetwork => _is_graph_network; | |||
public OptimizerV2 Optimizer | |||
{ | |||
@@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine | |||
_init_batch_counters(); | |||
} | |||
internal override void Initialize(LayerArgs args) | |||
{ | |||
_init_batch_counters(); | |||
base.Initialize(args); | |||
} | |||
void _configure_steps_per_execution(int steps_per_execution) | |||
{ | |||
_steps_per_execution = tf.Variable(steps_per_execution, | |||
@@ -81,10 +89,11 @@ namespace Tensorflow.Keras.Engine | |||
public override List<ILayer> Layers | |||
=> _flatten_layers(recursive: false, include_self: false).ToList(); | |||
public override List<IVariableV1> TrainableVariables | |||
public override List<IVariableV1> TrainableWeights | |||
{ | |||
get | |||
{ | |||
// skip the assertion of weights created. | |||
var variables = new List<IVariableV1>(); | |||
if (!Trainable) | |||
@@ -95,18 +104,40 @@ namespace Tensorflow.Keras.Engine | |||
foreach (var trackable_obj in _self_tracked_trackables) | |||
{ | |||
if (trackable_obj.Trainable) | |||
variables.AddRange(trackable_obj.TrainableVariables); | |||
variables.AddRange(trackable_obj.TrainableWeights); | |||
} | |||
foreach (var layer in _self_tracked_trackables) | |||
variables.AddRange(_trainable_weights); | |||
return variables.Distinct().ToList(); | |||
} | |||
} | |||
public override List<IVariableV1> NonTrainableWeights | |||
{ | |||
get | |||
{ | |||
// skip the assertion of weights created. | |||
var variables = new List<IVariableV1>(); | |||
foreach (var trackable_obj in _self_tracked_trackables) | |||
{ | |||
if (layer.Trainable) | |||
variables.AddRange(layer.TrainableVariables); | |||
variables.AddRange(trackable_obj.NonTrainableWeights); | |||
} | |||
// variables.AddRange(_trainable_weights); | |||
if (!Trainable) | |||
{ | |||
var trainable_variables = new List<IVariableV1>(); | |||
foreach (var trackable_obj in _self_tracked_trackables) | |||
{ | |||
trainable_variables.AddRange(trackable_obj.TrainableWeights); | |||
} | |||
variables.AddRange(trainable_variables); | |||
variables.AddRange(_trainable_weights); | |||
variables.AddRange(_non_trainable_weights); | |||
} | |||
return variables; | |||
return variables.Distinct().ToList(); | |||
} | |||
} | |||
@@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine | |||
: base(args.Inputs, args.Outputs, name: args.Name) | |||
{ | |||
this.args = args; | |||
if (args.Layers == null) | |||
args.Layers = new List<ILayer>(); | |||
// SupportsMasking = true; | |||
_compute_output_and_mask_jointly = true; | |||
_auto_track_sub_layers = false; | |||
@@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine | |||
_created_nodes = new List<INode>(); | |||
// Add to the model any layers passed to the constructor. | |||
if (args.Layers != null) | |||
if (args.Layers is not null) | |||
{ | |||
foreach (var layer in args.Layers) | |||
add(layer); | |||
InitLayers(args.Layers); | |||
} | |||
} | |||
public void InitLayers(IEnumerable<ILayer> layers) | |||
{ | |||
foreach(var layer in layers) | |||
{ | |||
add(layer); | |||
} | |||
} | |||
@@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||
{ | |||
throw new ValueError("Alpha must be a number greater than 0."); | |||
} | |||
_buildInputShape = input_shape; | |||
built = true; | |||
base.build(input_shape); | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
@@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||
} | |||
public override void build(Shape input_shape) | |||
{ | |||
_buildInputShape = input_shape; | |||
built = true; | |||
base.build(input_shape); | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
{ | |||
@@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers { | |||
if ( alpha < 0f ) { | |||
throw new ValueError("Alpha must be a number greater than 0."); | |||
} | |||
_buildInputShape = input_shape; | |||
built = true; | |||
base.build(input_shape); | |||
} | |||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
Tensor output = inputs; | |||
@@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers | |||
return outputs; | |||
} | |||
public static Dense from_config(LayerArgs args) | |||
{ | |||
return new Dense(args as DenseArgs); | |||
} | |||
} | |||
} |
@@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers | |||
name: Name); | |||
} | |||
public static InputLayer from_config(LayerArgs args) | |||
{ | |||
return new InputLayer(args as InputLayerArgs); | |||
} | |||
public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | |||
} | |||
} |
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics | |||
public virtual void reset_states() | |||
{ | |||
foreach (var v in weights) | |||
foreach (var v in Weights) | |||
v.assign(0); | |||
} | |||
@@ -4,6 +4,7 @@ using System.IO; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
namespace Tensorflow.Keras.Models | |||
@@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models | |||
public Functional from_config(ModelConfig config) | |||
=> Functional.from_config(config); | |||
public void load_model(string filepath, bool compile = true) | |||
public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||
{ | |||
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); | |||
var saved_mode = SavedModel.Parser.ParseFrom(bytes); | |||
var meta_graph_def = saved_mode.MetaGraphs[0]; | |||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); | |||
var metadata = SavedMetadata.Parser.ParseFrom(bytes); | |||
// Recreate layers and metrics using the info stored in the metadata. | |||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
keras_loader.load_layers(compile: compile); | |||
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||
} | |||
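// Usage sketch (path hypothetical; assumes this factory is exposed as `keras.models`): | |||
// var model = keras.models.load_model("./alexnet_savedmodel", compile: false); | |||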
} | |||
} |
@@ -1,12 +1,24 @@ | |||
using Newtonsoft.Json; | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.ComponentModel; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Reflection; | |||
using System.Text.RegularExpressions; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Layers.Rnn; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
using static Tensorflow.ApiDef.Types; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
@@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving | |||
{ | |||
public class KerasObjectLoader | |||
{ | |||
SavedMetadata _metadata; | |||
SavedObjectGraph _proto; | |||
Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); | |||
List<int> _traversed_nodes_from_config = new List<int>(); | |||
private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
private SavedMetadata _metadata; | |||
private SavedObjectGraph _proto; | |||
private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>(); | |||
private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>(); | |||
private List<int> _traversed_nodes_from_config = new List<int>(); | |||
private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes; | |||
private List<int> _models_to_reconstruct; | |||
public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes; | |||
static KerasObjectLoader() | |||
{ | |||
PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; | |||
} | |||
public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | |||
{ | |||
_metadata = metadata; | |||
_proto = object_graph_def; | |||
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); | |||
_models_to_reconstruct = new List<int>(); | |||
loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); | |||
} | |||
/// <summary> | |||
@@ -42,15 +66,255 @@ namespace Tensorflow.Keras.Saving | |||
continue; | |||
} | |||
_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||
} | |||
foreach(var node_metadata in metric_list) | |||
{ | |||
try | |||
{ | |||
if (node_metadata.Identifier.Equals("_tf_keras_metric")) | |||
{ | |||
continue; | |||
} | |||
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | |||
node_metadata.Metadata); | |||
} | |||
catch(ValueError) | |||
{ | |||
if (compile) | |||
{ | |||
throw; | |||
} | |||
// TODO: add logging.warning. | |||
} | |||
} | |||
} | |||
public string get_path(int node_id) | |||
{ | |||
return _node_paths[node_id]; | |||
} | |||
/// <summary> | |||
/// Finish setting up Keras objects. | |||
/// | |||
/// This function is executed after all objects and functions have been created. | |||
/// Call functions and losses are attached to each layer, and once all layers | |||
/// have been fully set up, graph networks are initialized. | |||
/// | |||
/// Subclassed models that are revived from the SavedModel are treated like | |||
/// layers, and have their call/loss functions attached here. | |||
/// </summary> | |||
public void finalize_objects() | |||
{ | |||
List<Layer> layers_revived_from_config = new(); | |||
List<Layer> layers_revived_from_saved_model = new(); | |||
foreach(var item in loaded_nodes) | |||
{ | |||
var node_id = item.Key; | |||
var node = item.Value.Item1; | |||
if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id)) | |||
{ | |||
continue; | |||
} | |||
_unblock_model_reconstruction(node_id, node as Layer); | |||
if(node is InputLayer or Metric) | |||
{ | |||
continue; | |||
} | |||
// TODO: deal with `RevivedLayer` and `RevivedInputLayer`. | |||
layers_revived_from_config.Add(node as Layer); | |||
} | |||
_finalize_saved_model_layers(layers_revived_from_saved_model); | |||
_finalize_config_layers(layers_revived_from_config); | |||
_reconstruct_all_models(); | |||
} | |||
private void _reconstruct_all_models() | |||
{ | |||
HashSet<int> all_initialized_models = new(); | |||
for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--) | |||
{ | |||
int model_id = _models_to_reconstruct[i]; | |||
all_initialized_models.Add(model_id); | |||
var (model, layers) = model_layer_dependencies[model_id]; | |||
_reconstruct_model(model_id, model, layers.ToList()); | |||
_finalize_config_layers(new List<Layer>() { model }); | |||
} | |||
Debug.Assert(all_initialized_models.SetEquals(model_layer_dependencies.Keys)); | |||
} | |||
        private void _reconstruct_model(int model_id, Model model, List<Layer> layers)
        {
            var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"];
            if (model.input is not null && model.input.Length > 0)
            {
                // The model inputs are already set; nothing to reconstruct.
            }
            else if (model is Sequential s)
            {
                if (layers is null || layers.Count == 0 || layers[0] is not InputLayer)
                {
                    if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
                    {
                        layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
                    }
                    else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
                    {
                        // TODO(Rinne): implement it
                    }
                }
                // `model.__init__(layers, config["name"])`
                s.InitLayers(layers);
                s.Name = config["name"].ToObject<string>();
                if (s.input is null || s.input.Length == 0)
                {
                    var first_layer = _get_child_layer_node_ids(model_id)[0];
                    var input_specs = _infer_inputs(first_layer);
                    var input_shapes = _infer_inputs(first_layer, true);
                    // `model._set_inputs(input_specs)`
                    // skip the check of whether input_specs is a Dictionary
                    if (!s.Built)
                    {
                        s.build(input_shapes);
                    }
                }
            }
            else
            {
                // skip the parameter `created_layers`.
                var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config),
                    layers.ToDictionary(x => x.Name, x => x as ILayer));
                // skip the `model.__init__`
                (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
                (model as Functional).connect_ancillary_layers(created_layers);
            }
            _set_network_attributes_from_metadata(model);
            _unblock_model_reconstruction(model_id, model);
        }
        private void _set_network_attributes_from_metadata(Model revived_object)
        {
            // TODO: implement it.
        }

        /// <summary>
        /// Runs the final steps of loading Keras Layers from config.
        /// </summary>
        /// <param name="layers"></param>
        private void _finalize_config_layers(List<Layer> layers)
        {
            foreach (var layer in layers)
            {
                if (_is_graph_network(layer))
                {
                    _restore_layer_unconditional_losses(layer);
                }
                _restore_layer_activation_loss(layer);
                _restore_layer_metrics(layer);
                // TODO(Rinne): deal with RNN.
            }
        }

        /// <summary>
        /// Runs the final steps of loading Keras Layers from SavedModel.
        /// </summary>
        /// <param name="layers"></param>
        private void _finalize_saved_model_layers(List<Layer> layers)
        {
            foreach (var layer in layers)
            {
                // TODO(Rinne): deal with `RevivedNetwork`.
                _restore_layer_unconditional_losses(layer);
                _restore_layer_activation_loss(layer);
                _restore_layer_metrics(layer);
            }
        }
        private void _restore_layer_unconditional_losses(Layer layer)
        {
            // TODO(Rinne): implement it.
        }

        private void _restore_layer_activation_loss(Layer layer)
        {
            // TODO(Rinne): implement it.
        }

        private void _restore_layer_metrics(Layer layer)
        {
            // TODO(Rinne): implement it.
        }
        /// <summary>
        /// Removes the layer from blocking model reconstruction.
        /// </summary>
        /// <param name="layer_id"></param>
        /// <param name="layer"></param>
        private void _unblock_model_reconstruction(int layer_id, Layer layer)
        {
            foreach (var dependency in model_layer_ids_dependencies)
            {
                var layer_ids = dependency.Value.Item2;
                var layers = model_layer_dependencies.SetDefault(dependency.Key,
                    (dependency.Value.Item1, new Layer[dependency.Value.Item2.Length])).Item2;
                if (!layer_ids.Contains(layer_id))
                {
                    continue;
                }
                layers[Array.IndexOf(layer_ids, layer_id)] = layer;
                if (layers.All(x => x is not null))
                {
                    _models_to_reconstruct.Add(dependency.Key);
                }
            }
        }
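        // Illustrative walk-through (comment only, not part of the change): if model 7 depends
        // on layer ids [2, 5] in `model_layer_ids_dependencies`, then
        // `_unblock_model_reconstruction(5, layer5)` fills slot 1 of the mirror array kept in
        // `model_layer_dependencies[7]`; once layer 2 has filled slot 0 as well, every slot is
        // non-null and model 7 is queued in `_models_to_reconstruct`.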
        private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
        {
            // Workaround: replace the string dtype with its TF_DataType enum value so that
            // `KerasMetaData` deserializes correctly.
            metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
            var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
            if (loaded_nodes.ContainsKey(node_id))
            {
                var (node, setter) = loaded_nodes[node_id];
                _maybe_add_serialized_attributes(node as Layer, metadata);
                var config = metadata.Config;
                if (_is_graph_network(node as Layer) && generic_utils.validate_config(config))
                {
                    Debug.Assert(node is Model);
                    var child_nodes = _get_child_layer_node_ids(node_id);
                    model_layer_ids_dependencies[node_id] = (node as Model, child_nodes);
                    if (child_nodes is null || child_nodes.Length == 0)
                    {
                        _models_to_reconstruct.Add(node_id);
                    }
                }
                return (node, setter);
            }
            else
            {
                var (obj, setter) = _revive_from_config(identifier, metadata, node_id);
                if (obj is null)
                {
                    (obj, setter) = _revive_custom_object(identifier, metadata);
                }
                Debug.Assert(obj is Layer);
                _maybe_add_serialized_attributes(obj as Layer, metadata);
                return (obj, setter);
            }
        }
        /// <summary>
@@ -59,11 +323,34 @@ namespace Tensorflow.Keras.Saving
        /// <param name="identifier"></param>
        /// <param name="metadata"></param>
        /// <param name="node_id"></param>
        private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
        {
            Trackable obj;
            if (identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
            {
                // TODO(Rinne): implement it.
                return (null, null);
                //throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
            }
            else
            {
                obj = _revive_graph_network(identifier, metadata, node_id);
                obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
            }
            if (obj is null)
            {
                return (null, null);
            }
            var setter = _config_node_setter(_revive_setter);
            _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id);
            return (obj, setter);
        }

        private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
        {
            // TODO(Rinne): implement it.
            throw new NotImplementedException();
        }
        Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id)
@@ -71,6 +358,12 @@ namespace Tensorflow.Keras.Saving
            var config = metadata.Config;
            var class_name = metadata.ClassName;
            Model model = null;
            if (!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional")
            {
                return null;
            }
            if (class_name == "Sequential")
            {
                model = new Sequential(new SequentialArgs
@@ -78,34 +371,82 @@ namespace Tensorflow.Keras.Saving
                    Name = config.GetValue("name").ToString()
                });
            }
            else if (identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER)
            {
                model = new Sequential(new SequentialArgs
                {
                    Name = class_name
                });
            }
            else
            {
                model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>());
            }

            // Record this model and its layers. This will later be used to reconstruct
            // the model.
            var layers = _get_child_layer_node_ids(node_id);
            model_layer_ids_dependencies[node_id] = (model, layers);
            if (layers is null || layers.Length == 0)
            {
                _models_to_reconstruct.Add(node_id);
            }
            return model;
        }
        Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
        {
            var config = metadata.Config;
            var class_name = metadata.ClassName;
            var shared_object_id = metadata.SharedObjectId;
            var must_restore_from_config = metadata.MustRestoreFromConfig;

            var obj = generic_utils.deserialize_keras_object(class_name, config);
            obj.Name = metadata.Name;
            // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec`

            var built = _try_build_layer(obj, node_id, metadata.BuildInputShape);
            if (!built)
            {
                return null;
            }
            return obj;
        }
        private void _revive_setter(object layer, object name, object value)
        {
            Debug.Assert(name is string);
            Debug.Assert(layer is Layer);
            if (PUBLIC_ATTRIBUTES.ContainsKey(name as string))
            {
                if (value is Trackable)
                {
                    (layer as Layer)._track_trackable(value as Trackable, name as string);
                }
                if ((layer as Layer).SerializedAttributes is null)
                {
                    (layer as Layer).SerializedAttributes = new JObject();
                }
                (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value);
            }
            else if (layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-\d+").Success)
            {
                (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true);
            }
            else
            {
                // Only set the attribute when the corresponding property has not been set yet.
                var properties = layer.GetType().GetProperties();
                foreach (var p in properties)
                {
                    if (p.Name == name as string && p.GetValue(layer) is not null)
                    {
                        return;
                    }
                }
                Loader.setattr(layer, name, value);
            }
        }
        /// <summary>
@@ -143,34 +484,186 @@ namespace Tensorflow.Keras.Saving
        /// <param name="obj"></param>
        /// <param name="proto"></param>
        /// <param name="node_id"></param>
        void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id)
        {
            if (_traversed_nodes_from_config.Contains(node_id))
                return;
            var parent_path = _node_paths[node_id];
            _traversed_nodes_from_config.Add(node_id);
            obj._maybe_initialize_trackable();
            if (obj is Layer layer && !layer.Built)
            {
                var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata);
                _try_build_layer(layer, node_id, metadata.BuildInputShape);
            }
            List<(Trackable, int, string)> children = new();
            foreach (var refer in proto.Children)
            {
                var obj_child = obj._lookup_dependency(refer.LocalName);
                children.Add((obj_child, refer.NodeId, refer.LocalName));
            }
            var metric_list_node_id = _search_for_child_node(node_id, new string[] {
                Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics"
            });
            if (metric_list_node_id is not null && obj is Model model && model.metrics is not null)
            {
                var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x);
                foreach (var refer in _proto.Nodes[metric_list_node_id.Value].Children)
                {
                    if (obj_metrics.TryGetValue(refer.LocalName, out var metric))
                    {
                        var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}";
                        children.Add((metric as Metric, refer.NodeId, metric_path));
                    }
                }
            }
            foreach (var (obj_child, child_id, child_name) in children)
            {
                if (obj_child is null)
                {
                    continue;
                }
                var child_proto = _proto.Nodes[child_id];

                // skip the check for registered identifier

                Action<object, object, object> setter;
                if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier))
                {
                    setter = _revive_setter;
                }
                else
                {
                    setter = Loader.setattr;
                }

                if (loaded_nodes.ContainsKey(child_id))
                {
                    // skip the logging.warning
                    continue;
                }

                if (child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name))
                {
                    (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0";
                }

                if (obj_child is TrackableDataStructure)
                {
                    setter = (x, y, z) => { };
                }

                var child_path = $"{parent_path}.{child_name}";
                _node_paths[child_id] = child_path;
                _add_children_recreated_from_config(obj_child, child_proto, child_id);
                loaded_nodes[child_id] = (obj_child, setter);
            }
        }
        private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape)
        {
            if (obj.Built)
                return true;
            if (build_input_shape is null)
            {
                build_input_shape = _infer_inputs(node_id, convert_to_shapes: true);
            }
            if (build_input_shape is not null)
            {
                obj.build(build_input_shape);
                // In TF python there is a call `base_layer.Layer.build(obj, build_input_shape)` here.
                // On the one hand, C# cannot invoke a method non-virtually on a specified base class.
                // On the other hand, every class currently derived from Layer either calls
                // `Layer.build` or moves the implementation of `Layer.build` into its own `build`
                // method, so we do not call it here.
                // However, this is still risky if a future subclass of `Layer` does not call `Layer.build`.
                return true;
            }
            return false;
        }
        /// <summary>
        /// Infers the input shape of a layer from the SavedModel functions.
        /// </summary>
        /// <param name="layer_node_id"></param>
        /// <param name="convert_to_shapes"></param>
        /// <returns></returns>
        private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false)
        {
            var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" });
            if (call_fn_id is null)
            {
                return null;
            }
            var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions;
            if (concrete_functions is null)
            {
                return null;
            }
            var call_fn_name = concrete_functions[0];
            var call_fn_proto = _proto.ConcreteFunctions[call_fn_name];
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        }
        private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child)
        {
            if (path_to_child is null || path_to_child.Count() == 0)
            {
                return parent_id;
            }
            foreach (var child in _proto.Nodes[parent_id].Children)
            {
                if (child.LocalName == path_to_child.First())
                {
                    return _search_for_child_node(child.NodeId, path_to_child.Skip(1));
                }
            }
            return null;
        }
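        // Example (illustrative, comment only): `_search_for_child_node(node_id, new[] { "keras_api", "layer_metrics" })`
        // follows the child whose LocalName matches each path segment in turn and returns the node id
        // of `<node>.keras_api.layer_metrics`, or null if any segment is missing.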
        private bool _is_graph_network(Layer layer)
        {
            // TODO: deal with `RevivedLayer`
            if (layer is Functional)
            {
                return (layer as Functional).IsGraphNetwork || layer is Sequential;
            }
            return false;
        }

        private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata)
        {
            // TODO: deal with `RevivedLayer`
        }
        /// <summary>
        /// Creates edges for nodes that are recreated from config.
        /// </summary>
        /// <returns></returns>
        private Action<object, object, object> _config_node_setter(Action<object, object, object> setter)
        {
            void setattr_wrapper(object obj, object name, object value)
            {
                Debug.Assert(obj is Trackable);
                Debug.Assert(name is string);
                if ((obj as Trackable)._lookup_dependency(name as string) is null)
                {
                    setter(obj, name, value);
                }
            }
            return setattr_wrapper;
        }
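        // Illustrative effect of the wrapper (comment only, not part of the change): given
        // `setter = _config_node_setter(_revive_setter)`, a call `setter(layer, "kernel", v)`
        // is a no-op when `layer` already tracks a dependency named "kernel" (i.e. it was
        // already recreated from config), and otherwise delegates to `_revive_setter`.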
    }
}
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
    public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
        SaveOptions? options, bool save_traces = true)
    {
        if (!overwrite && File.Exists(filepath))
@@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils
                BadConsumers = { }
            },
            Identifier = layer.ObjectIdentifier,
            Metadata = layer.GetTrackingMetadata()
        };

        metadata.Nodes.Add(saved_object);
@@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils
            if (x is ResourceVariable or RefVariable) return (Trackable)x;
            else throw new TypeError($"The type {x.GetType()} is not supported for the wrapping of layer.");
        }));
        var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x =>
        {
            if (x is ResourceVariable or RefVariable) return (Trackable)x;
            else throw new TypeError($"The type {x.GetType()} is not supported for the wrapping of layer.");
@@ -0,0 +1,96 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Saving.SavedModel
{
    public class KerasLoadModelUtils
    {
        /// <summary>
        /// Corresponds to keras/saving/save.py/load_model.
        /// </summary>
        /// <param name="filepath"></param>
        /// <param name="custom_objects"></param>
        /// <param name="compile"></param>
        /// <param name="options"></param>
        /// <returns></returns>
        public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
            bool compile = true, LoadOptions? options = null)
        {
            using (SharedObjectSavingScope.Enter())
            {
                using (LoadContext.load_context(options))
                {
                    if (!File.Exists(filepath) && !Directory.Exists(filepath))
                    {
                        throw new IOException($"No file or directory found at {filepath}.");
                    }
                    if (Directory.Exists(filepath))
                    {
                        return load(filepath, compile, options);
                    }
                    else
                    {
                        throw new NotImplementedException("Loading a model in h5 format is not supported yet. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it is needed.");
                    }
                }
            }
        }
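        // A minimal usage sketch (illustrative, not part of this change); the path below is
        // hypothetical and must point at a SavedModel directory produced by
        // `Model.save(..., save_format: "tf")`.
        private static void ExampleLoad()
        {
            var model = load_model("./alexnet_from_sequential") as Model;
            model.summary();
            // The optimizer state is not restored yet (see the TODO in `load`), so callers
            // re-compile before further training.
        }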
        private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
        {
            SavedMetadata metadata = new SavedMetadata();
            var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
            var object_graph_def = meta_graph_def.ObjectGraphDef;
            string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
            if (File.Exists(path_to_metadata_pb))
            {
                metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read));
            }
            else
            {
                throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
            }

            if (metadata.Nodes is null || metadata.Nodes.Count == 0)
            {
                return Loader.load(path, options: options) as Model;
            }

            var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
            keras_loader.load_layers(compile: compile);

            Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
            nodes_to_load["root"] = (null, null);
            foreach (var item in keras_loader.LoadedNodes)
            {
                nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
            }
            var loaded = Loader.load_partial(path, nodes_to_load, options);
            keras_loader.finalize_objects();
            // keras_loader.del_tracking();

            var model = loaded["root"];

            if (model is Model && compile)
            {
                // TODO(Rinne): implement it.
            }

            if (!tf.Context.executing_eagerly())
            {
                // TODO(Rinne): implement it.
            }

            return model;
        }
    }
}
@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow.Keras.Saving.SavedModel
{
    // TODO: move this class to a common project.
    public class ContextHandler : IDisposable
    {
        public Action<bool> DisposeCallBack { get; set; }
        public void Dispose()
        {
            DisposeCallBack.Invoke(true);
        }
    }

    public class LoadContext
    {
        private bool _entered_load_context;
        private LoadOptions? _load_options;
        private static ThreadLocal<LoadContext> _load_context = new();

        private LoadContext()
        {
            _entered_load_context = false;
            _load_options = null;
        }

        public void set_load_options(LoadOptions load_options)
        {
            _load_options = load_options;
            _entered_load_context = true;
        }

        private void clear_load_options()
        {
            _load_options = null;
            _entered_load_context = false;
        }

        private LoadOptions? load_options()
        {
            return _load_options;
        }

        public static ContextHandler load_context(LoadOptions? load_options)
        {
            if (_load_context.Value is null)
            {
                _load_context.Value = new LoadContext();
            }
            _load_context.Value.set_load_options(load_options);
            return new ContextHandler()
            {
                DisposeCallBack = _ => _load_context.Value.clear_load_options()
            };
        }
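        // A minimal sketch of the intended call pattern (illustrative, not part of the change):
        // the handler clears the thread-local options when the `using` block exits.
        private static void ExampleUsage(LoadOptions? options)
        {
            using (load_context(options))
            {
                // `get_load_option()` returns `options` while inside the block.
                var active = get_load_option();
            }
        }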
        public static LoadOptions? get_load_option()
        {
            return _load_context.Value.load_options();
        }

        public static bool in_load_context()
        {
            return _load_context.Value._entered_load_context;
        }
    }
}
@@ -19,15 +19,21 @@ using Newtonsoft.Json.Linq;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Train;

namespace Tensorflow.Keras.Utils
{
    public class generic_utils
    {
        private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config";
        /// <summary>
        /// This method has no corresponding method in python. It is close to `serialize_keras_object`.
        /// </summary>
@@ -51,6 +57,58 @@ namespace Tensorflow.Keras.Utils
            return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance);
        }
        public static Layer deserialize_keras_object(string class_name, JToken config)
        {
            var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
            var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
                .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
            var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
            var args = deserializationGenericMethod.Invoke(config, null);
            var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
            Debug.Assert(layer is Layer);
            return layer as Layer;
        }
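        // A hedged usage sketch (illustrative, not part of the change): reviving a layer purely
        // from its class name and JSON config. The reflection above resolves
        // `Tensorflow.Keras.ArgsDefinition.FlattenArgs` and constructs
        // `Tensorflow.Keras.Layers.Flatten` from it; the config keys are assumed to match the
        // args definition.
        private static Layer ExampleDeserialize()
        {
            var config = JToken.Parse("{\"name\": \"flatten\", \"data_format\": \"channels_last\"}");
            return deserialize_keras_object("Flatten", config);
        }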
        public static Layer deserialize_keras_object(string class_name, LayerArgs args)
        {
            var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
            Debug.Assert(layer is Layer);
            return layer as Layer;
        }

        public static LayerArgs deserialize_layer_args(string class_name, JToken config)
        {
            var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
            var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
                .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
            var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
            var args = deserializationGenericMethod.Invoke(config, null);
            Debug.Assert(args is LayerArgs);
            return args as LayerArgs;
        }

        public static ModelConfig deserialize_model_config(JToken json)
        {
            ModelConfig config = new ModelConfig();
            config.Name = json["name"].ToObject<string>();
            config.Layers = new List<LayerConfig>();
            var layersToken = json["layers"];
            foreach (var token in layersToken)
            {
                var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);
                config.Layers.Add(new LayerConfig()
                {
                    Config = args,
                    Name = token["name"].ToObject<string>(),
                    ClassName = token["class_name"].ToObject<string>(),
                    InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
                });
            }
            config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();
            config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>();
            return config;
        }
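        // Shape of the JSON this method expects (abridged from a real keras_metadata.pb):
        // {
        //   "name": "model",
        //   "layers": [ { "class_name": "Dense", "name": "dense", "config": { ... },
        //                 "inbound_nodes": [[["flatten", 0, 0, {}]]] }, ... ],
        //   "input_layers": [["input_1", 0, 0]],
        //   "output_layers": [["softmax", 0, 0]]
        // }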
        public static string to_snake_case(string name)
        {
            return string.Concat(name.Select((x, i) =>
@@ -60,5 +118,15 @@ namespace Tensorflow.Keras.Utils
                    x.ToString();
            })).ToLower();
        }
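        // e.g. (illustrative) to_snake_case("InputLayer") => "input_layer",
        //      to_snake_case("Dense") => "dense".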
        /// <summary>
        /// Determines whether config appears to be a valid layer config.
        /// </summary>
        /// <param name="config"></param>
        /// <returns></returns>
        public static bool validate_config(JObject config)
        {
            return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY);
        }
    }
}
@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils
            }

            var trainable_count = count_params(model, model.TrainableVariables);
            var non_trainable_count = count_params(model, model.NonTrainableVariables);

            print($"Total params: {trainable_count + non_trainable_count}");
            print($"Trainable params: {trainable_count}");
@@ -0,0 +1,9 @@
(binary test asset: Assets/simple_model_from_auto_compile/keras_metadata.pb, the serialized Keras metadata for the Functional MNIST model InputLayer(28x28x1) -> Flatten -> Dense(100, relu) -> Dense(10) -> Softmax, plus the loss and sparse_categorical_accuracy metric nodes; the raw protobuf bytes are omitted here because they are not human-readable.)
@@ -0,0 +1,68 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelLoad
{
    [TestMethod]
    public void SimpleModelFromAutoCompile()
    {
        var model = keras.models.load_model(@"Assets/simple_model_from_auto_compile");
        model.summary();

        model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        // check the weights
        var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
        var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");

        Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second));
        Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second));

        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 8;

        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 50000,
        }).Result;

        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
    }

    [TestMethod]
    public void AlexnetFromSequential()
    {
        new SequentialModelSave().AlexnetFromSequential();
        var model = keras.models.load_model(@"./alexnet_from_sequential");
        model.summary();

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

        var num_epochs = 1;
        var batch_size = 8;

        var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

        model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
    }
}
@@ -1,27 +1,21 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;
using System.Diagnostics;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelSave
{
    [TestMethod]
    public void SimpleModelFromAutoCompile()
@@ -63,6 +57,8 @@ public class SequentialModelTest
            keras.layers.Softmax(1)
        });

        model.summary();

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        var data_loader = new MnistModelLoader();
@@ -82,7 +78,7 @@ public class SequentialModelTest
    }

    [TestMethod]
    public void AlexnetFromSequential()
    {
        Model model = KerasApi.keras.Sequential(new List<ILayer>()
        {
@@ -116,7 +112,7 @@ public class SequentialModelTest
            keras.layers.Softmax(1)
        });

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

        var num_epochs = 1;
        var batch_size = 8;
@@ -125,7 +121,7 @@ public class SequentialModelTest
        model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

        model.save("./alexnet_from_sequential", save_format: "tf");

        // The saved model can be tested with the following python code:
        #region alexnet_python_code
@@ -27,4 +27,28 @@
    <ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" />
  </ItemGroup>

  <ItemGroup>
    <None Update="Assets\simple_model_from_auto_compile\fingerprint.pb">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\simple_model_from_auto_compile\keras_metadata.pb">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\simple_model_from_auto_compile\saved_model.pb">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\simple_model_from_auto_compile\variables\variables.data-00000-of-00001">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\simple_model_from_auto_compile\variables\variables.index">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\simple_model_from_auto_compile\kernel1.npy">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
    <None Update="Assets\simple_model_from_auto_compile\bias0.npy">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
  </ItemGroup>
</Project>