* Add CheckpointReader and corresponding C APIs. * Add essential components of SavedModel format loading. * Add checkpoint reading for SavedModel format loading. * Revise customized json converters. * Add support for loading models from python. * Fix the duplicated weights in Keras.Model. * Add alexnet loading test and check for loaded weights. * Fix ci error caused by branch merge. * Resolve the comments and errors. * Fix the stucking of training when loading model. * Fix the stucking of training when loading model. * fix intptr. --------- Co-authored-by: Haiping Chen <haiping008@gmail.com>tags/v0.100.4-load-saved-model
@@ -149,4 +149,22 @@ public static class CheckPointUtils | |||||
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | ||||
// } | // } | ||||
} | } | ||||
/// <summary> | |||||
/// Traverse the object graph and list all accessible objects. | |||||
/// </summary> | |||||
/// <param name="object_graph_view"></param> | |||||
public static IList<Trackable> list_objects(ObjectGraphView graph_view) | |||||
{ | |||||
return objects_ids_and_slot_variables_and_paths(graph_view).Item1; | |||||
} | |||||
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||||
{ | |||||
return full_list.TakeWhile(x => | |||||
{ | |||||
var saveables = x.gather_saveables_for_checkpoint(); | |||||
return saveables is not null && saveables.Count > 0; | |||||
}); | |||||
} | |||||
} | } |
@@ -0,0 +1,100 @@ | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Checkpoint | |||||
{ | |||||
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle | |||||
{ | |||||
public SafeCheckpointReaderHandle(): base() | |||||
{ | |||||
} | |||||
public SafeCheckpointReaderHandle(IntPtr handle): base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TF_DeleteCheckpointReader(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
public class CheckpointReader | |||||
{ | |||||
private SafeCheckpointReaderHandle _handle; | |||||
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | |||||
public Dictionary<string, Shape> VariableToShapeMap { get; set; } | |||||
public CheckpointReader(string filename) | |||||
{ | |||||
Status status = new Status(); | |||||
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | |||||
status.Check(true); | |||||
ReadAllShapeAndType(); | |||||
} | |||||
public int HasTensor(string name) | |||||
{ | |||||
return c_api.TF_CheckpointReaderHasTensor(_handle, name); | |||||
} | |||||
/// <summary> | |||||
/// Get the variable name. | |||||
/// </summary> | |||||
/// <param name="index"></param> | |||||
/// <returns></returns> | |||||
public string GetVariable(int index) | |||||
{ | |||||
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); | |||||
} | |||||
public int Size() | |||||
{ | |||||
return c_api.TF_CheckpointReaderSize(_handle); | |||||
} | |||||
public TF_DataType GetVariableDataType(string name) | |||||
{ | |||||
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); | |||||
} | |||||
public Shape GetVariableShape(string name) | |||||
{ | |||||
int num_dims = GetVariableNumDims(name); | |||||
long[] dims = new long[num_dims]; | |||||
Status status = new Status(); | |||||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | |||||
status.Check(true); | |||||
return new Shape(dims); | |||||
} | |||||
public int GetVariableNumDims(string name) | |||||
{ | |||||
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); | |||||
} | |||||
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
Status status = new Status(); | |||||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); | |||||
status.Check(true); | |||||
return new Tensor(tensor); | |||||
} | |||||
private void ReadAllShapeAndType() | |||||
{ | |||||
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | |||||
VariableToShapeMap = new Dictionary<string, Shape>(); | |||||
int size = Size(); | |||||
for(int i = 0; i < size; i++) | |||||
{ | |||||
var name = GetVariable(i); | |||||
var shape = GetVariableShape(name); | |||||
var dtype = GetVariableDataType(name); | |||||
VariableToDataTypeMap[name] = dtype; | |||||
VariableToShapeMap[name] = shape; | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -175,9 +175,9 @@ public static class SaveUtilV1 | |||||
{ | { | ||||
var name = factory_data.name; | var name = factory_data.name; | ||||
var key = factory_data.checkpoint_key; | var key = factory_data.checkpoint_key; | ||||
var maybe_saveable = factory_data.factory; | |||||
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); | |||||
// TODO: oneflow python has a process with callable `saveable_factory`. | |||||
// TODO: tensorflow python has a process with callable `saveable_factory`. | |||||
List<MySaveableObject> saveables = new(); | List<MySaveableObject> saveables = new(); | ||||
if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | ||||
{ | { | ||||
@@ -217,7 +217,7 @@ public static class SaveUtilV1 | |||||
public record class CheckpointFactoryData | public record class CheckpointFactoryData | ||||
( | ( | ||||
Maybe<BaseResourceVariable, MySaveableObject> factory, | |||||
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||||
string name, | string name, | ||||
string checkpoint_key | string checkpoint_key | ||||
); | ); |
@@ -0,0 +1,27 @@ | |||||
using System.Runtime.InteropServices; | |||||
using Tensorflow.Checkpoint; | |||||
namespace Tensorflow | |||||
{ | |||||
    public unsafe partial class c_api
    {
        // P/Invoke declarations for the TF_CheckpointReader C API.
        // NOTE(review): `string` parameters use the default platform marshalling; the
        // TensorFlow C API expects UTF-8 paths/names — confirm non-ASCII filenames round-trip.

        /// <summary>Creates a native checkpoint reader for <paramref name="filename"/>; errors are reported through <paramref name="status"/>.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);

        /// <summary>Releases a reader created by <see cref="TF_NewCheckpointReader"/>. Takes a raw pointer because it is called from SafeHandle.ReleaseHandle.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern void TF_DeleteCheckpointReader(IntPtr reader);

        /// <summary>Returns non-zero when the checkpoint contains a tensor named <paramref name="name"/>.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);

        /// <summary>Returns a pointer to the name of the variable at <paramref name="index"/>; decode with <c>c_api.StringPiece</c>.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);

        /// <summary>Number of variables stored in the checkpoint.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);

        /// <summary>Data type of the named variable.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);

        /// <summary>Fills <paramref name="dims"/> (length <paramref name="num_dims"/>) with the named variable's shape.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);

        /// <summary>Rank (number of dimensions) of the named variable.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);

        /// <summary>Reads the named tensor's value out of the checkpoint.</summary>
        [DllImport(TensorFlowLibName)]
        internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
    }
} |
@@ -6,8 +6,12 @@ using System.Linq; | |||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Exceptions; | |||||
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.Operations; | |||||
using Newtonsoft.Json; | |||||
using Tensorflow.Training; | |||||
namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
@@ -21,8 +25,20 @@ public class TrackableSaver | |||||
private TrackableObjectGraph _last_save_object_graph; | private TrackableObjectGraph _last_save_object_graph; | ||||
private Tensor? _object_graph_feed_tensor = null; | private Tensor? _object_graph_feed_tensor = null; | ||||
private Tensor? _file_prefix_feed_tensor = null; | private Tensor? _file_prefix_feed_tensor = null; | ||||
private Tensor? _file_prefix_placeholder = null; | |||||
private Dictionary<Trackable, Trackable>? _object_map = null; | private Dictionary<Trackable, Trackable>? _object_map = null; | ||||
private object? _cache = null; | private object? _cache = null; | ||||
public Tensor? FilePrefixPlaceHolder | |||||
{ | |||||
get | |||||
{ | |||||
return _file_prefix_placeholder; | |||||
} | |||||
set | |||||
{ | |||||
_file_prefix_placeholder = value; | |||||
} | |||||
} | |||||
public TrackableSaver(ObjectGraphView graph_view) | public TrackableSaver(ObjectGraphView graph_view) | ||||
{ | { | ||||
_graph_view = graph_view; | _graph_view = graph_view; | ||||
@@ -192,4 +208,366 @@ public class TrackableSaver | |||||
return save_path; | return save_path; | ||||
} | } | ||||
} | } | ||||
    /// <summary>
    /// Restores a checkpoint previously written with `save` into the objects reachable
    /// from this saver's graph view.
    /// </summary>
    /// <param name="save_path">Checkpoint prefix to restore from. When null, nothing is
    /// restored and an <see cref="InitializationOnlyStatus"/> is returned.</param>
    /// <param name="options">Checkpoint options; a default instance is used when null.</param>
    /// <returns>A <see cref="LoadStatus"/> describing the restoration.</returns>
    public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
    {
        if (options is null)
        {
            options = new CheckpointOptions();
        }
        if (save_path is null)
        {
            return new InitializationOnlyStatus(_graph_view, ops.uid());
        }

        CheckpointReader reader = new CheckpointReader(save_path);
        // NOTE(review): TensorFlow python computes `graph_building = not executing_eagerly()`;
        // the negation is missing here, so the two branches below run in the opposite mode —
        // confirm this is intentional before relying on the graph-mode path.
        bool graph_building = tf.Context.executing_eagerly();
        Dictionary<string, TF_DataType> dtype_map = null;
        if (!graph_building)
        {
            dtype_map = reader.VariableToDataTypeMap;
        }
        // The serialized object graph lives in the checkpoint under a reserved key.
        Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);

        Dictionary<Tensor, string> file_prefix_feed_dict;
        Tensor file_prefix_tensor;
        if (graph_building)
        {
            // Reuse (or lazily create) the prefix tensor and feed the real path at run time.
            if (_file_prefix_placeholder is null)
            {
                // NOTE(review): tf.device here is a plain call, not a scoped context —
                // the previous device is not restored afterwards; confirm intended.
                tf.device("/cpu:0");
                _file_prefix_placeholder = constant_op.constant("model");
            }
            file_prefix_tensor = _file_prefix_placeholder;
            file_prefix_feed_dict = new();
            file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
        }
        else
        {
            tf.device("/cpu:0");
            file_prefix_tensor = constant_op.constant(save_path);
            file_prefix_feed_dict = null;
        }

        // Parse the TrackableObjectGraph proto out of the string tensor; scalar string
        // tensors are unpacked differently from higher-rank buffers.
        TrackableObjectGraph object_graph_proto = new();
        if (object_graph_string.ndim > 0)
        {
            object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
        }
        else
        {
            object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]);
        }
        CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
            object_graph_proto: object_graph_proto,
            save_path: save_path,
            save_path_tensor: file_prefix_tensor,
            reader: reader,
            restore_op_cache: null,
            graph_view: _graph_view,
            options: options,
            saveables_cache: null
        );
        // Node 0 of the proto is the root of the object graph.
        new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root);

        // Attached dependencies are not children of the root, so they are restored separately.
        if (_graph_view.AttachedDependencies is not null)
        {
            foreach (var refer in _graph_view.AttachedDependencies)
            {
                if (refer.Name == "root")
                {
                    continue;
                }
                int? proto_id = null;
                // Find proto ID of attached dependency (if it is in the proto).
                foreach (var proto_refer in object_graph_proto.Nodes[0].Children)
                {
                    if (proto_refer.LocalName == refer.Name)
                    {
                        proto_id = proto_refer.NodeId;
                        break;
                    }
                }
                if (proto_id is null)
                {
                    continue;
                }
                // Object has already been restored. This can happen when there's an
                // indirect connection from the attached object to the root.
                if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value))
                {
                    continue;
                }
                new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer);
            }
        }
        return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view);
    }
} | |||||
/// <summary>
/// Holds the state shared by every <see cref="CheckpointPosition"/> during one restore
/// pass: the parsed object graph, the reader, matched/unmatched proto ids, restore ops
/// and deferred slot-variable restorations.
/// </summary>
public class CheckpointRestoreCoordinator
{
    private CheckpointOptions _options;
    private TrackableObjectGraph _object_graph_proto;
    private int _restore_uid;
    private HashSet<int> _matched_proto_ids;
    private Tensor _save_path_tensor;
    private string _save_path_string;
    private CheckpointReader _reader;
    private Dictionary<string, TF_DataType> _dtype_map;
    private Dictionary<string, Shape> _shape_map;
    private ObjectGraphView _graph_view;
    private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations;
    private bool _expect_partial_attr;
    private List<Operation> _restore_ops;
    private List<Trackable> _all_trackables;
    private Dictionary<int, Trackable> _object_by_proto_id;
    private Dictionary<string, Operation> _restore_ops_by_name;
    private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations;
    private Dictionary<int, IList<string>> _unused_attributes;

    public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
        CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
    {
        // TODO(Rinne): cache.
        _options = options;
        _object_graph_proto = object_graph_proto;
        _restore_uid = ops.uid();
        _save_path_tensor = save_path_tensor;
        _save_path_string = save_path;
        _reader = reader ?? new CheckpointReader(save_path);
        _dtype_map = _reader.VariableToDataTypeMap;
        _shape_map = _reader.VariableToShapeMap;
        _graph_view = graph_view;
        _restore_ops = new List<Operation>();
        _restore_ops_by_name = new Dictionary<string, Operation>();
        _all_trackables = new List<Trackable>();
        _matched_proto_ids = new HashSet<int>();
        _object_by_proto_id = new Dictionary<int, Trackable>();
        _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>();
        _deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>();
        // BUG FIX: `_unused_attributes` was never initialized, so the first
        // `UnusedAttributes.SetDefault(...)` call (see CheckpointPosition) threw a
        // NullReferenceException.
        _unused_attributes = new Dictionary<int, IList<string>>();
        _expect_partial_attr = false;

        // Index slot-variable references by the node id of the variable they belong to,
        // so slot restoration can be deferred until that variable is restored.
        for (int i = 0; i < _object_graph_proto.Nodes.Count; i++)
        {
            var node = _object_graph_proto.Nodes[i];
            foreach (var slot_reference in node.SlotVariables)
            {
                _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>())
                    .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
            }
        }
        // skip the deleter and cache.
    }

    /// <summary>
    /// Whether an incomplete match between checkpoint and objects is acceptable.
    /// </summary>
    public bool ExpectPartial
    {
        get => _expect_partial_attr;
        set => _expect_partial_attr = value;
    }

    /// <summary>
    /// Corresponding to `all_python_objects` of tensorflow python
    /// </summary>
    public List<Trackable> AllTrackables => _all_trackables;
    public HashSet<int> MatchedProtoIds => _matched_proto_ids;
    public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
    public int RestoreUid => _restore_uid;
    public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
    public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations;
    public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations;
    public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name;
    public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes;

    /// <summary>
    /// Records newly created restore ops for later retrieval.
    /// </summary>
    public void new_restore_ops(IEnumerable<Operation> new_ops)
    {
        _restore_ops.AddRange(new_ops);
        // skip the callback.
    }

    /// <summary>
    /// Builds and runs the restore ops for the gathered saveables.
    /// </summary>
    /// <param name="tensor_saveables">Checkpoint key -> saveable; each must hold a <see cref="BaseResourceVariable"/>.</param>
    /// <param name="positions">Positions using direct tensor restoration — not implemented yet.</param>
    /// <param name="registered_savers">Unused; registered savers are not supported.</param>
    /// <exception cref="TypeError">When a saveable does not hold a <see cref="BaseResourceVariable"/>.</exception>
    public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
    {
        List<Operation> restore_ops = new();
        // Any position taking this path is unsupported; fail loudly rather than silently skip.
        foreach (var position in positions)
        {
            var key = position.ObjectProto.Attributes[0].CheckpointKey;
            throw new NotImplementedException();
        }

        Dictionary<string, BaseResourceVariable> variable_dict = new();
        foreach (var item in tensor_saveables)
        {
            if (item.Value.TryGet<BaseResourceVariable>(out var variable))
            {
                variable_dict[item.Key] = variable;
            }
            else
            {
                throw new TypeError();
            }
        }

        if (tensor_saveables is not null && tensor_saveables.Count > 0)
        {
            var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict);
            var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options);
            if (!tf.Context.executing_eagerly())
            {
                // Graph mode: remember the ops by checkpoint key so they can be reused.
                foreach (var item in new_restore_ops)
                {
                    restore_ops.Add(item.Value);
                    Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key));
                    _restore_ops_by_name[item.Key] = item.Value;
                }
            }
        }
        return restore_ops;
    }
}
/// <summary>
/// Result of a checkpoint restoration attempt; subclasses report whether (and how
/// completely) the checkpoint matched the object graph.
/// </summary>
public abstract class LoadStatus
{
    public abstract LoadStatus assert_consumed();
    public abstract LoadStatus assert_existing_objects_matched();
    public abstract LoadStatus assert_nontrivial_match();
    public abstract LoadStatus run_restore_ops(Session? session = null);
    public abstract void initialize_or_restore(Session? session = null);

    /// <summary>
    /// By default, silently tolerate values left unrestored in the checkpoint.
    /// </summary>
    public virtual LoadStatus expect_partial() => this;
}
/// <summary>
/// Load status returned when no checkpoint path was given: objects can only be
/// initialized from their defaults, nothing is restored.
/// </summary>
public class InitializationOnlyStatus : LoadStatus
{
    private int _restore_uid;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
    {
        _restore_uid = restore_uid;
        _object_graph_view = object_graph_view;
        _root = object_graph_view.Root;
    }

    public override LoadStatus assert_consumed() =>
        throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");

    public override LoadStatus assert_existing_objects_matched() =>
        throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");

    public override LoadStatus assert_nontrivial_match() =>
        throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");

    public override LoadStatus run_restore_ops(Session? session = null) =>
        throw new AssertionError("No checkpoint specified, so no restore ops are available "
            + "(save_path=None to Saver.restore).");

    public override void initialize_or_restore(Session? session = null)
    {
        // Eager variables initialize themselves on creation — nothing to do.
        if (tf.Context.executing_eagerly())
        {
            return;
        }
        session ??= new Session();
        var trackable_objects = CheckPointUtils.list_objects(_object_graph_view);
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}
/// <summary>
/// Load status returned when a checkpoint was actually read; can assert that the
/// checkpoint matched the object graph.
/// </summary>
internal class CheckpointLoadStatus : LoadStatus
{
    private CheckpointRestoreCoordinator _checkpoint;
    private Dictionary<Tensor, string> _feed_dict;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view) : base()
    {
        _checkpoint = checkpoint;
        _feed_dict = feed_dict;
        _object_graph_view = graph_view;
        _root = graph_view.Root;
    }

    public CheckpointRestoreCoordinator Checkpoint => _checkpoint;

    public override LoadStatus assert_consumed()
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Asserts that every object reachable before the restore received a value from
    /// the checkpoint.
    /// </summary>
    /// <exception cref="AssertionError">When a matched node was not assigned a value,
    /// or objects with checkpointed attributes were left unmatched.</exception>
    public override LoadStatus assert_existing_objects_matched()
    {
        for (int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++)
        {
            var node = _checkpoint.ObjectGraphProto.Nodes[i];
            if (_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) &&
                trackable.UpdateUid < _checkpoint.RestoreUid)
            {
                throw new AssertionError($"Object {node} not assigned a value from checkpoint.");
            }
        }
        foreach (var trackable_object in CheckPointUtils.list_objects(_object_graph_view))
        {
            // Empty data structures carry no state of their own; skip them.
            if (trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0)
            {
                continue;
            }
            _checkpoint.AllTrackables.Add(trackable_object);
        }
        // Materialize once: the deferred LINQ query would otherwise be re-evaluated by
        // each of Any()/Count()/Take().
        var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables)
            .Except(_checkpoint.ObjectByProtoId.Values)
            .ToList();
        if (unused_trackables.Count > 0)
        {
            var num_unused_trackables = unused_trackables.Count;
            var num_variables_to_show = Math.Min(10, num_unused_trackables);
            // BUG FIX: the message previously ended with the literal python fragment
            // "{list(unused_python_objects)[:num_variables_to_show]}" instead of the objects.
            var shown = string.Join(", ", unused_trackables.Take(num_variables_to_show));
            throw new AssertionError($"Found {num_unused_trackables} objects that were " +
                $"not bound to checkpointed values, likely due to changes in the " +
                $"program. Showing {num_variables_to_show} of " +
                $"{num_unused_trackables} unmatched objects: [{shown}]");
        }
        return this;
    }

    public override LoadStatus assert_nontrivial_match()
    {
        throw new NotImplementedException();
    }

    public override LoadStatus expect_partial()
    {
        throw new NotImplementedException();
    }

    public override void initialize_or_restore(Session? session = null)
    {
        throw new NotImplementedException();
    }

    public override LoadStatus run_restore_ops(Session? session = null)
    {
        throw new NotImplementedException();
    }
}
@@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint | |||||
// tf python has code `with ops.device(restore_device):` here. | // tf python has code `with ops.device(restore_device):` here. | ||||
tf.device(restore_device); // may be risky. | tf.device(restore_device); // may be risky. | ||||
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||||
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||||
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | ||||
int idx = 0; | int idx = 0; | ||||
@@ -0,0 +1,331 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Checkpoint; | |||||
public class CheckpointPosition | |||||
{ | |||||
    // Coordinator shared by every CheckpointPosition created during one restore pass.
    private CheckpointRestoreCoordinator _checkpoint;
    // Index of this position's node in the TrackableObjectGraph proto.
    private int _proto_id;
    // NOTE(review): set to false in the constructor and never written elsewhere in this
    // chunk — presumably reserved for partial-restore support; confirm before relying on it.
    private bool _skip_restore;

    public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id)
    {
        _checkpoint = checkpoint;
        _proto_id = proto_id;
        _skip_restore = false;
    }

    // The trackable currently bound to this proto node (set via bind_project).
    public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];
    public CheckpointRestoreCoordinator Checkpoint => _checkpoint;
    // This position's node in the serialized object graph.
    public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id];
public void restore(Trackable trackable) | |||||
{ | |||||
using (ops.init_scope()) | |||||
{ | |||||
if (bind_project(trackable)) | |||||
{ | |||||
var restore_ops = _restore_descendants(); | |||||
if(restore_ops is not null && restore_ops.Count > 0) | |||||
{ | |||||
_checkpoint.new_restore_ops(restore_ops); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Set a checkpoint<->object correspondence. | |||||
/// </summary> | |||||
/// <param name="trackable"></param> | |||||
/// <returns></returns> | |||||
public bool bind_project(Trackable trackable) | |||||
{ | |||||
_checkpoint.AllTrackables.Add(trackable); | |||||
_checkpoint.MatchedProtoIds.Add(_proto_id); | |||||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||||
{ | |||||
// skip the `logging.warning`. | |||||
return false; | |||||
} | |||||
else | |||||
{ | |||||
_checkpoint.ObjectByProtoId[_proto_id] = trackable; | |||||
return true; | |||||
} | |||||
} | |||||
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||||
{ | |||||
// skip the registered_saver | |||||
if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||||
{ | |||||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||||
new List<CheckpointPosition>(), null); | |||||
} | |||||
var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); | |||||
List<Operation> existing_restore_ops; | |||||
List<CheckpointPosition> positions = new(); | |||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||||
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||||
{ | |||||
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||||
} | |||||
else if(saveable_factories.Count > 0) | |||||
{ | |||||
(existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return (existing_restore_ops, named_saveables, positions, null); | |||||
} | |||||
public CheckpointPosition create_child_position(int node_id) | |||||
{ | |||||
return new CheckpointPosition(_checkpoint, node_id); | |||||
} | |||||
    /// <summary>
    /// Would create a restore position for an optimizer slot variable.
    /// </summary>
    /// <returns>Currently always (null, null): slot-variable restoration is not implemented,
    /// and callers must tolerate the null pair.</returns>
    public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable,
        int slot_variable_id, string slot_name)
    {
        //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id);
        // TODO(Rinne): implement it.
        return (null, null);
    }
/// <summary> | |||||
/// Creates a saveable using the _serialize_to_tensor method. | |||||
/// </summary> | |||||
/// <param name="saveable_factories"></param> | |||||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
{ | |||||
string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||||
suffix = suffix ?? ""; | |||||
var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; | |||||
if (!tf.Context.executing_eagerly()) | |||||
{ | |||||
throw new NotImplementedException("The restore under graph mode has not been implemented. " + | |||||
"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
} | |||||
var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||||
// skip the cache. | |||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||||
dict[saveable_name] = saveable; | |||||
return (new List<Operation>(), dict); | |||||
} | |||||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||||
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
if(ObjectProto.Attributes is null) | |||||
{ | |||||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||||
} | |||||
List<Operation> existing_restore_ops = new(); | |||||
HashSet<string> created_compat_names = new(); | |||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||||
foreach (var serialized_tensor in ObjectProto.Attributes) | |||||
{ | |||||
Operation existing_op; | |||||
if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) | |||||
{ | |||||
existing_op = null; | |||||
} | |||||
else | |||||
{ | |||||
existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; | |||||
} | |||||
if(existing_op is not null) | |||||
{ | |||||
existing_restore_ops.Add(existing_op); | |||||
continue; | |||||
} | |||||
if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) | |||||
{ | |||||
continue; | |||||
} | |||||
// TODO(Rinne): deal with cache. | |||||
var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); | |||||
if(saveable is null) | |||||
{ | |||||
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||||
continue; | |||||
} | |||||
named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||||
} | |||||
return (existing_restore_ops, named_saveables); | |||||
} | |||||
private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||||
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||||
{ | |||||
var expected_factory_name = serialized_tensor.Name; | |||||
var factory_input_name = serialized_tensor.CheckpointKey; | |||||
if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) | |||||
{ | |||||
foreach(var item in saveable_factories) | |||||
{ | |||||
var factory_name = item.Key; | |||||
var factory = item.Value; | |||||
if (expected_factory_name.StartsWith(factory_name)) | |||||
{ | |||||
if(matched_factory is not null) | |||||
{ | |||||
throw new ValueError($"Forward compatibility load error: Unable to load " + | |||||
"checkpoint saved in future version of TensorFlow. " + | |||||
"Please update your version of TensorFlow to the " + | |||||
"version in which the checkpoint was saved."); | |||||
} | |||||
} | |||||
matched_factory = factory; | |||||
factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; | |||||
created_compat_names.Add(factory_name); | |||||
} | |||||
} | |||||
return matched_factory(factory_input_name); | |||||
} | |||||
private string _extract_saveable_name(string checkpoint_key) | |||||
{ | |||||
var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; | |||||
return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); | |||||
} | |||||
/// <summary> | |||||
/// Restore the bound Trackable and dependencies (may be deferred). | |||||
/// </summary> | |||||
private List<Operation> _restore_descendants() | |||||
{ | |||||
Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||||
visit_queue.Enqueue((this, this.Trackable)); | |||||
List<Operation> restore_ops = new(); | |||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||||
List<CheckpointPosition> positions = new(); | |||||
CheckpointPosition current_position = null; | |||||
while (visit_queue.Count > 0) | |||||
{ | |||||
current_position = visit_queue.Dequeue().Item1; | |||||
var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore(); | |||||
restore_ops.AddRange(new_restore_ops); | |||||
foreach(var item in new_tensor_saveables) | |||||
{ | |||||
tensor_saveables.Add(item.Key, item.Value); | |||||
} | |||||
positions.AddRange(new_positions); | |||||
_queue_children_for_restoration(current_position, visit_queue); | |||||
_queue_slot_variables(current_position, visit_queue); | |||||
} | |||||
restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null)); | |||||
return restore_ops; | |||||
} | |||||
private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) | |||||
{ | |||||
var trackable = checkpoint_position.Trackable; | |||||
foreach(var child in checkpoint_position.ObjectProto.Children) | |||||
{ | |||||
var child_position = checkpoint_position.create_child_position(child.NodeId); | |||||
var local_object = trackable._lookup_dependency(child.LocalName); | |||||
var child_proto = child_position.ObjectProto; | |||||
if(local_object is null) | |||||
{ | |||||
if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any()) | |||||
{ | |||||
trackable.DeferredDependencies.SetDefault(child.LocalName, new List<CheckpointPosition>()).Add(child_position); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
if (child_position.bind_project(local_object)) | |||||
{ | |||||
visit_queue.Enqueue((child_position, local_object)); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) | |||||
{ | |||||
var trackable = checkpoint_position.Trackable; | |||||
var checkpoint = checkpoint_position.Checkpoint; | |||||
if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions)) | |||||
{ | |||||
checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id); | |||||
foreach (var deferred_slot_restoration in positions) | |||||
{ | |||||
var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position( | |||||
trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId, | |||||
deferred_slot_restoration.SlotName | |||||
); | |||||
if(slot_variable_position is not null) | |||||
{ | |||||
visit_queue.Enqueue((slot_variable_position, slot_variable)); | |||||
} | |||||
} | |||||
} | |||||
if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations)) | |||||
{ | |||||
checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id); | |||||
foreach (var slot_restoration in restorations) | |||||
{ | |||||
if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object)) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
// TODO(Rinne); implement it. | |||||
} | |||||
else | |||||
{ | |||||
Debug.Assert(trackable is BaseResourceVariable); | |||||
Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List<DeferredSlotVariableRestoration>()) | |||||
.Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName)); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||||
{ | |||||
var trackable = this.Trackable; | |||||
trackable._maybe_initialize_trackable(); | |||||
if(_checkpoint.RestoreUid > trackable.UpdateUid) | |||||
{ | |||||
var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables(); | |||||
trackable.UpdateUid = _checkpoint.RestoreUid; | |||||
return (restore_ops, tensor_saveables, positions, registered_savers); | |||||
} | |||||
else | |||||
{ | |||||
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||||
new List<CheckpointPosition>(), null); | |||||
} | |||||
} | |||||
} | |||||
/// <summary>
/// Records a slot-variable restoration that must wait until its optimizer object has
/// been loaded: the variable the slot belongs to, the slot variable's node id in the
/// object graph, and the slot's name.
/// </summary>
public record class DeferredSlotVariableRestoration(
    BaseResourceVariable OriginalVariable,
    int SlotVariableId,
    string SlotName
);
@@ -10,7 +10,7 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
internal class execute | |||||
internal static class execute | |||||
{ | { | ||||
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | ||||
{ | { | ||||
@@ -27,5 +27,9 @@ namespace Tensorflow.Eager | |||||
return tensors; | return tensors; | ||||
} | } | ||||
public static bool must_record_gradient() | |||||
{ | |||||
return false; | |||||
} | |||||
} | } | ||||
} | } |
@@ -13,8 +13,8 @@ namespace Tensorflow.Functions | |||||
/// </summary> | /// </summary> | ||||
public class ConcreteFunction: Trackable | public class ConcreteFunction: Trackable | ||||
{ | { | ||||
FuncGraph func_graph; | |||||
ForwardBackwardCall forward_backward; | |||||
internal FuncGraph func_graph; | |||||
internal ForwardBackwardCall forward_backward; | |||||
public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
@@ -23,6 +23,8 @@ namespace Tensorflow.Functions | |||||
public Tensor[] Outputs; | public Tensor[] Outputs; | ||||
public Type ReturnType; | public Type ReturnType; | ||||
public TensorSpec[] OutputStructure; | public TensorSpec[] OutputStructure; | ||||
public IEnumerable<string> ArgKeywords { get; set; } | |||||
public long NumPositionArgs { get; set; } | |||||
public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
{ | { | ||||
@@ -163,6 +165,15 @@ namespace Tensorflow.Functions | |||||
return flat_outputs; | return flat_outputs; | ||||
} | } | ||||
public void AddTograph(Graph? g = null) | |||||
{ | |||||
if(!tf.Context.executing_eagerly() && g is null) | |||||
{ | |||||
g = ops.get_default_graph(); | |||||
} | |||||
// TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||||
} | |||||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
{ | { | ||||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | ||||
@@ -16,8 +16,10 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.IO | namespace Tensorflow.IO | ||||
{ | { | ||||
@@ -63,5 +65,15 @@ namespace Tensorflow.IO | |||||
dirs.AddRange(Directory.GetFiles(dir)); | dirs.AddRange(Directory.GetFiles(dir)); | ||||
return dirs.ToArray(); | return dirs.ToArray(); | ||||
} | } | ||||
public string join(params string[] paths) | |||||
{ | |||||
Debug.Assert(paths.Length >= 1); | |||||
if (paths[0].Substring(1).Contains("://")) | |||||
{ | |||||
throw new NotImplementedException("The combination of urls has not been implemented."); | |||||
} | |||||
return Path.Combine(paths); | |||||
} | |||||
} | } | ||||
} | } |
@@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
{ | { | ||||
var axis = serializer.Deserialize(reader, typeof(long[])); | |||||
int[]? axis; | |||||
if(reader.ValueType == typeof(long)) | |||||
{ | |||||
axis = new int[1]; | |||||
axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | |||||
} | |||||
else | |||||
{ | |||||
axis = serializer.Deserialize(reader, typeof(int[])) as int[]; | |||||
} | |||||
if (axis is null) | if (axis is null) | ||||
{ | { | ||||
throw new ValueError("Cannot deserialize 'null' to `Axis`."); | throw new ValueError("Cannot deserialize 'null' to `Axis`."); | ||||
@@ -0,0 +1,36 @@ | |||||
using Newtonsoft.Json.Linq; | |||||
using Newtonsoft.Json; | |||||
namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// JSON converter for <see cref="TF_DataType"/>: writes the dtype as its numpy-style
    /// name (via dtypes.as_numpy_name) and reads back either a name string or a raw
    /// integer enum value.
    /// </summary>
    public class CustomizedDTypeJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(TF_DataType);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            // Emit the numpy-style name instead of the numeric enum value.
            var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
            token.WriteTo(writer);
        }

        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            if (reader.ValueType == typeof(string))
            {
                // Stored as a numpy-style dtype name.
                var str = (string)serializer.Deserialize(reader, typeof(string));
                return dtypes.tf_dtype_from_name(str);
            }
            else
            {
                // Stored as the raw integer value of the enum.
                return (TF_DataType)serializer.Deserialize(reader, typeof(int));
            }
        }
    }
}
@@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common | |||||
{ | { | ||||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | throw new ValueError("Cannot deserialize 'null' to `Shape`."); | ||||
} | } | ||||
if(values.Length != 3) | |||||
if(values.Length == 1) | |||||
{ | |||||
var array = values[0] as JArray; | |||||
if(array is null) | |||||
{ | |||||
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||||
} | |||||
values = array.ToObject<object[]>(); | |||||
} | |||||
if (values.Length < 3) | |||||
{ | { | ||||
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | ||||
} | } | ||||
@@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common | |||||
{ | { | ||||
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | ||||
} | } | ||||
if (values[1] is not int) | |||||
int nodeIndex; | |||||
int tensorIndex; | |||||
if (values[1] is long) | |||||
{ | |||||
nodeIndex = (int)(long)values[1]; | |||||
} | |||||
else if (values[1] is int) | |||||
{ | |||||
nodeIndex = (int)values[1]; | |||||
} | |||||
else | |||||
{ | { | ||||
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | ||||
} | } | ||||
if (values[2] is not int) | |||||
if (values[2] is long) | |||||
{ | |||||
tensorIndex = (int)(long)values[2]; | |||||
} | |||||
else if (values[1] is int) | |||||
{ | |||||
tensorIndex = (int)values[2]; | |||||
} | |||||
else | |||||
{ | { | ||||
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | ||||
} | } | ||||
return new NodeConfig() | return new NodeConfig() | ||||
{ | { | ||||
Name = values[0] as string, | Name = values[0] as string, | ||||
NodeIndex = (int)values[1], | |||||
TensorIndex = (int)values[2] | |||||
NodeIndex = nodeIndex, | |||||
TensorIndex = tensorIndex | |||||
}; | }; | ||||
} | } | ||||
} | } | ||||
@@ -51,10 +51,28 @@ namespace Tensorflow.Keras.Common | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
{ | { | ||||
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
if(dims is null) | |||||
long?[] dims; | |||||
try | |||||
{ | { | ||||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
} | |||||
catch (JsonSerializationException ex) | |||||
{ | |||||
if (reader.Value.Equals("class_name")) | |||||
{ | |||||
reader.Read(); | |||||
reader.Read(); | |||||
reader.Read(); | |||||
dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
} | |||||
else | |||||
{ | |||||
throw ex; | |||||
} | |||||
} | |||||
if (dims is null) | |||||
{ | |||||
return null; | |||||
} | } | ||||
long[] convertedDims = new long[dims.Length]; | long[] convertedDims = new long[dims.Length]; | ||||
for(int i = 0; i < dims.Length; i++) | for(int i = 0; i < dims.Length; i++) | ||||
@@ -19,6 +19,7 @@ namespace Tensorflow.Keras | |||||
List<IVariableV1> TrainableVariables { get; } | List<IVariableV1> TrainableVariables { get; } | ||||
List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
List<IVariableV1> Weights { get; } | |||||
Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
TensorShapeConfig BuildInputShape { get; } | TensorShapeConfig BuildInputShape { get; } | ||||
@@ -1,8 +1,11 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using Newtonsoft.Json.Linq; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||||
namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
{ | { | ||||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training.Saving.SavedModel; | |||||
namespace Tensorflow.ModelSaving | namespace Tensorflow.ModelSaving | ||||
{ | { | ||||
@@ -71,6 +71,7 @@ namespace Tensorflow | |||||
public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | ||||
public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | ||||
public List<IVariableV1> Weights => throw new NotImplementedException(); | |||||
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | ||||
public Shape OutputShape => throw new NotImplementedException(); | public Shape OutputShape => throw new NotImplementedException(); | ||||
@@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations | |||||
/// | /// | ||||
/// Callers must ensure all the named tensors are indeed stored in the checkpoint. | /// Callers must ensure all the named tensors are indeed stored in the checkpoint. | ||||
/// </remarks> | /// </remarks> | ||||
public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||||
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||||
{ | { | ||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
Dictionary<string, object> attrs = new(); | |||||
attrs["dtypes"] = dtypes; | |||||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||||
"RestoreV2", name, prefix, tensor_names, shape_and_slices | |||||
) | |||||
{ attrs = attrs }); | |||||
return result; | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
try | |||||
{ | |||||
return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
} | |||||
} | |||||
} | |||||
var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
dict["prefix"] = prefix; | dict["prefix"] = prefix; | ||||
dict["tensor_names"] = tensor_names; | dict["tensor_names"] = tensor_names; | ||||
@@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations | |||||
return (tensors); | return (tensors); | ||||
} | } | ||||
public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) | |||||
{ | |||||
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||||
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||||
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||||
object[] attrs = new object[] { "dtypes", dtypes }; | |||||
Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; | |||||
var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); | |||||
if (execute.must_record_gradient()) | |||||
{ | |||||
// TODO(Rinne); record the gradient | |||||
} | |||||
return result; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Reverses specific dimensions of a tensor. | /// Reverses specific dimensions of a tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -62,6 +62,7 @@ namespace Tensorflow | |||||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | ||||
{ | { | ||||
// Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. | |||||
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | ||||
return _op.outputs; | return _op.outputs; | ||||
@@ -17,8 +17,8 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training.Saving.SavedModel; | |||||
using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
@@ -1,9 +1,13 @@ | |||||
namespace Tensorflow | |||||
using Newtonsoft.Json; | |||||
using Tensorflow.Keras.Common; | |||||
namespace Tensorflow | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | ||||
/// The enum values here are identical to corresponding values in types.proto. | /// The enum values here are identical to corresponding values in types.proto. | ||||
/// </summary> | /// </summary> | ||||
[JsonConverter(typeof(CustomizedDTypeJsonConverter))] | |||||
public enum TF_DataType | public enum TF_DataType | ||||
{ | { | ||||
DtInvalid = 0, | DtInvalid = 0, | ||||
@@ -159,7 +159,10 @@ namespace Tensorflow | |||||
"uint32" => TF_DataType.TF_UINT32, | "uint32" => TF_DataType.TF_UINT32, | ||||
"int64" => TF_DataType.TF_INT64, | "int64" => TF_DataType.TF_INT64, | ||||
"uint64" => TF_DataType.TF_UINT64, | "uint64" => TF_DataType.TF_UINT64, | ||||
"float16" => TF_DataType.TF_BFLOAT16, | |||||
"float32" => TF_DataType.TF_FLOAT, | |||||
"single" => TF_DataType.TF_FLOAT, | "single" => TF_DataType.TF_FLOAT, | ||||
"float64" => TF_DataType.TF_DOUBLE, | |||||
"double" => TF_DataType.TF_DOUBLE, | "double" => TF_DataType.TF_DOUBLE, | ||||
"complex" => TF_DataType.TF_COMPLEX128, | "complex" => TF_DataType.TF_COMPLEX128, | ||||
"string" => TF_DataType.TF_STRING, | "string" => TF_DataType.TF_STRING, | ||||
@@ -39,6 +39,24 @@ namespace Tensorflow | |||||
_op = value; | _op = value; | ||||
} | } | ||||
} | } | ||||
public BaseResourceVariable variable | |||||
{ | |||||
get | |||||
{ | |||||
if (_op.TryGet<BaseResourceVariable>(out var v)) | |||||
{ | |||||
return v; | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("The _op is not a variable."); | |||||
} | |||||
} | |||||
set | |||||
{ | |||||
_op = value; | |||||
} | |||||
} | |||||
public SaveSpec[] specs; | public SaveSpec[] specs; | ||||
public string name; | public string name; | ||||
public string device; | public string device; | ||||
@@ -0,0 +1,23 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow
{
    /// <summary>
    /// Options controlling how a SavedModel is loaded.
    /// </summary>
    public record class LoadOptions
    {
        // Whether loading may proceed with an incomplete checkpoint.
        public bool allow_partial_checkpoint;
        // Device to use for checkpoint I/O, if any.
        public string experimental_io_device;
        // When true, checkpoint restoration is skipped entirely.
        public bool experimental_skip_checkpoint;
        // Policy governing how variables are handled during load.
        public VariablePolicy experimental_variable_policy;

        public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null,
            bool experimental_skip_checkpoint = false, string experimental_variable_policy = null)
        {
            this.allow_partial_checkpoint = allow_partial_checkpoint;
            this.experimental_io_device = experimental_io_device;
            this.experimental_skip_checkpoint = experimental_skip_checkpoint;
            // The policy is accepted as a string name and converted to a policy object;
            // null maps to the default policy (VariablePolicy.from_obj handles both).
            this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy);
        }
    }
}
@@ -1,4 +1,5 @@ | |||||
using Tensorflow.Train; | |||||
using System; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow; | namespace Tensorflow; | ||||
@@ -14,4 +15,10 @@ public class RevivedTypes | |||||
// TODO: complete the implementation. | // TODO: complete the implementation. | ||||
return null; | return null; | ||||
} | } | ||||
public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto) | |||||
{ | |||||
// TODO: complete the implementation. | |||||
return null; | |||||
} | |||||
} | } |
@@ -2,7 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.ModelSaving | |||||
namespace Tensorflow | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Options for saving to SavedModel. | /// Options for saving to SavedModel. | ||||
@@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving | |||||
public bool save_variable_devices() | public bool save_variable_devices() | ||||
{ | { | ||||
return this != VariablePolicy.None; | |||||
return this != None; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static VariablePolicy from_obj(object obj) | public static VariablePolicy from_obj(object obj) | ||||
{ | { | ||||
if (obj is null) return VariablePolicy.None; | |||||
if (obj is null) return None; | |||||
if (obj is VariablePolicy) return (VariablePolicy)obj; | if (obj is VariablePolicy) return (VariablePolicy)obj; | ||||
var key = obj.ToString().ToLower(); | var key = obj.ToString().ToLower(); | ||||
return key switch | return key switch | ||||
{ | { | ||||
null => VariablePolicy.None, | |||||
"save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||||
"expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||||
null => None, | |||||
"save_variable_devices" => SAVE_VARIABLE_DEVICES, | |||||
"expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, | |||||
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | ||||
}; | }; | ||||
} | } |
@@ -5,7 +5,6 @@ using System.Linq; | |||||
using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
@@ -0,0 +1,22 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Functions; | |||||
namespace Tensorflow.Training.Saving.SavedModel
{
    /// <summary>
    /// A class wraps a concrete function to handle different distributed contexts.
    /// </summary>
    internal class WrapperFunction: ConcreteFunction
    {
        public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
        {
            // Copy the wrapped function's state so the wrapper behaves identically
            // to the function it wraps.
            this.forward_backward = concrete_function.forward_backward;
            this.Outputs = concrete_function.Outputs;
            this.ReturnType = concrete_function.ReturnType;
            this.OutputStructure = concrete_function.OutputStructure;
            this.ArgKeywords = concrete_function.ArgKeywords;
        }
    }
}
@@ -0,0 +1,36 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Training.Saving.SavedModel
{
    public static class function_deserialization
    {
        /// <summary>
        /// Configures a restored bare concrete function from its saved proto: copies the
        /// argument keywords and allowed positional-argument count onto the previously
        /// loaded concrete function and adds it to the graph.
        /// </summary>
        /// <param name="saved_bare_concrete_function">Proto describing the saved function.</param>
        /// <param name="concrete_functions">Previously loaded concrete functions, keyed by name.</param>
        /// <returns>The configured concrete function.</returns>
        public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
            IDictionary<string, ConcreteFunction> concrete_functions)
        {
            var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName];
            concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
            concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;
            // NOTE(review): the deserialized spec is computed but never used here —
            // the port looks incomplete; confirm against the Python original.
            var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
            concrete_function.AddTograph();
            return concrete_function;
        }

        /// <summary>
        /// Copies the fields of a saved FunctionSpec proto into a fresh FunctionSpec.
        /// </summary>
        private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
        {
            // TODO(Rinne): revise the implementation.
            return new FunctionSpec()
            {
                Fullargspec = function_spec_proto.Fullargspec,
                IsMethod = function_spec_proto.IsMethod,
                InputSignature = function_spec_proto.InputSignature,
                JitCompile = function_spec_proto.JitCompile
            };
        }
    }
}
@@ -0,0 +1,641 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using System.Net.Sockets; | |||||
using System.Text; | |||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using static Tensorflow.Binding; | |||||
using System.Runtime.CompilerServices; | |||||
using Tensorflow.Variables; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Training.Saving.SavedModel; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// Helper class to load an object-based SavedModel. | |||||
/// </summary> | |||||
public partial class Loader | |||||
{ | |||||
private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def; | |||||
private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes; | |||||
private SavedObjectGraph _proto; | |||||
private string _export_dir; | |||||
private CheckpointOptions _checkpoint_options; | |||||
private LoadOptions _save_options; | |||||
private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters; | |||||
private Dictionary<string, int>? _node_path_to_id; | |||||
private List<int>? _filtered_nodes; | |||||
private List<int> _ordered_node_ids; | |||||
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||||
private List<Trackable> _nodes; | |||||
private Dictionary<int, Action<object, object, object>> _node_setters; | |||||
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | |||||
CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | |||||
{ | |||||
var meta_graph = saved_model_proto.MetaGraphs[0]; | |||||
_asset_file_def = meta_graph.AssetFileDef; | |||||
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | |||||
_proto = object_graph_proto; | |||||
_export_dir = export_dir; | |||||
// TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||||
_checkpoint_options = ckpt_options; | |||||
_save_options = save_options; | |||||
// TODO: `this._pretty_printer` | |||||
_node_filters = filters; | |||||
_node_path_to_id = _convert_node_paths_to_ints(); | |||||
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); | |||||
foreach(var filter in filters) | |||||
{ | |||||
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; | |||||
} | |||||
_filtered_nodes = _retrieve_all_filtered_nodes(); | |||||
_ordered_node_ids = _generate_ordered_node_ids(); | |||||
_load_all(); | |||||
if (!save_options.experimental_skip_checkpoint) | |||||
{ | |||||
_restore_checkpoint(); | |||||
} | |||||
foreach(var node in _nodes) | |||||
{ | |||||
// skip the process of `CapturableResource`. | |||||
} | |||||
} | |||||
/// <summary>
/// Maps all string node paths in node_filters to the int node ids.
/// A path such as "root.layer.kernel" is resolved segment by segment
/// starting from node 0 (the root).
/// </summary>
/// <returns>path-to-id mapping, or null when no filters were supplied.</returns>
private Dictionary<string, int>? _convert_node_paths_to_ints()
{
    if (_node_filters is null)
    {
        return null;
    }
    var result = new Dictionary<string, int>();
    foreach (var path_key in _node_filters.Keys)
    {
        var segments = path_key.Split('.');
        if (segments[0] != "root")
        {
            throw new ValueError($"When passing string identifiers to node_filters, the first name" +
                $" must be root. Received {segments[0]}.");
        }
        // Walk the children starting at the root node (id 0).
        var current_id = 0;
        for (int idx = 1; idx < segments.Length; idx++)
        {
            current_id = _find_node_child(current_id, segments[idx], String.Join(".", segments.Take(idx)));
        }
        result[path_key] = current_id;
    }
    return result;
}
/// <summary>
/// Returns the node id of the child named `child_name` under `node_id`,
/// or raises a ValueError mentioning the full `path` when absent.
/// </summary>
private int _find_node_child(int node_id, string child_name, string path)
{
    var match = _proto.Nodes[node_id].Children.FirstOrDefault(c => c.LocalName == child_name);
    if (match is not null)
    {
        return match.NodeId;
    }
    throw new ValueError($"Unable to find node {path}.");
}
/// <summary>
/// Breadth-first traversal from the filter paths collecting every reachable
/// node id. Also populates `_node_path_to_id` for child paths and may add
/// tracked child objects to `_loaded_nodes` as a side effect.
/// </summary>
/// <returns>
/// The reachable node ids, or null when either no filters were given or the
/// root (node 0) is reachable — in which case everything is loaded anyway.
/// </returns>
private List<int>? _retrieve_all_filtered_nodes()
{
    if(_node_filters is null)
    {
        return null;
    }
    HashSet<int> all_filtered_nodes = new();
    Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys);
    while(nodes_to_visit.Count > 0)
    {
        var node_path = nodes_to_visit.Dequeue();
        var node_id = _node_path_to_id[node_path];
        if (all_filtered_nodes.Contains(node_id))
        {
            continue;
        }
        all_filtered_nodes.Add(node_id);
        Trackable node = null;
        Action<object, object, object> setter = null;
        if(_loaded_nodes.TryGetValue(node_id, out var res))
        {
            (node, setter) = res;
        }
        if(node is not null)
        {
            // Ensure user-supplied objects can answer `_lookup_dependency` below.
            node._maybe_initialize_trackable();
        }
        foreach(var refer in _proto.Nodes[node_id].Children)
        {
            Trackable children_object = null;
            if(_loaded_nodes.TryGetValue(refer.NodeId, out var result))
            {
                children_object = result.Item1;
            }
            // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict.
            if(children_object is null && node is not null)
            {
                children_object = node._lookup_dependency(refer.LocalName);
                if(children_object is TrackableDataStructure)
                {
                    // TODO: set setter as lambda.
                    _loaded_nodes[refer.NodeId] = (children_object, setter);
                }
            }
            // Record the child's path so later lookups (and the queue) can use it.
            string child_path = $"{node_path}.{refer.LocalName}";
            _node_path_to_id[child_path] = refer.NodeId;
            nodes_to_visit.Enqueue(child_path);
        }
    }
    if (all_filtered_nodes.Contains(0))
    {
        // Root is reachable: the whole graph gets loaded, so no filter set is needed.
        return null;
    }
    return all_filtered_nodes.ToList();
}
/// <summary>
/// Orders the node ids so that dependencies appear first.
/// Builds a dependency map (node id -> ids it depends on) and topologically
/// sorts it; slot variables additionally depend on their optimizer, their
/// original variable, and the previously listed slot variable.
/// </summary>
/// <returns>node ids in dependency-first order.</returns>
/// <exception cref="ValueError">
/// when a filtered load omits a required dependency, or a dependency cycle exists.
/// </exception>
private List<int> _generate_ordered_node_ids()
{
    List<int> unordered_ids;
    if (_filtered_nodes is null)
    {
        unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList();
    }
    else
    {
        unordered_ids = new List<int>(_filtered_nodes);
    }
    Dictionary<int, List<int>> dependency_map = new();
    foreach (var node_id in unordered_ids)
    {
        var deps = dependency_map.SetDefault(node_id, new List<int>());
        // Caller-provided objects are already created; their dependencies are irrelevant.
        if (_loaded_nodes.ContainsKey(node_id))
        {
            continue;
        }
        var proto = _proto.Nodes[node_id];
        foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
        {
            deps.Add(dep);
            if (_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
            {
                // TODO: add info with `_pretty_printer`.
                throw new ValueError($"Unable to partially load SavedModel since the specified filter " +
                    $"does not include all required objects for loading (e.g. " +
                    $"variables used in functions or deserialization dependencies). " +
                    $"Please include this path in the filter: {dep}");
            }
        }
        int? prev_slot = null;
        foreach (var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            // The optimizer and original variable must be created before the slot
            // variable, since the slot variable is generated using the Optimizer's
            // add_slot API.
            var slot_deps = dependency_map[slot_variable_node_id];
            slot_deps.Add(node_id);
            slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);
            // Chain slot variables so they are recreated in their original order.
            if (prev_slot is not null)
            {
                slot_deps.Add(prev_slot.Value);
            }
            prev_slot = slot_variable_node_id;
        }
    }
    try
    {
        return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
    }
    catch (TrackableUtils.CyclicDependencyError)
    {
        // The previous message concatenated fragments without separating spaces
        // ("dependenciesin", "pleasefile"); fixed here.
        throw new ValueError("Encountered a cycle in the deserialization dependencies " +
            "in the SavedModel. This is extremely unexpected, please " +
            "file a bug and make sure you are not manually modifying the SavedModel.");
    }
}
/// <summary>
/// Returns a dictionary of all dependencies of an object.
/// Keys are either local names or bound-input indices; values are node ids.
/// </summary>
/// <param name="proto">the saved object whose dependencies are collected.</param>
/// <returns>dependency name/index to node id mapping.</returns>
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto)
{
    var dependencies = new Dictionary<Maybe<string, int>, int>();
    foreach (var child_ref in proto.Dependencies)
    {
        dependencies[child_ref.LocalName] = child_ref.NodeId;
    }
    switch (proto.KindCase)
    {
        case SavedObject.KindOneofCase.Function:
            // Functions depend on every tensor bound into their concrete functions.
            foreach (var fn_name in proto.Function.ConcreteFunctions)
            {
                foreach (var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
                {
                    dependencies[bound_input] = bound_input;
                }
            }
            break;
        case SavedObject.KindOneofCase.BareConcreteFunction:
        {
            var fn_name = proto.BareConcreteFunction.ConcreteFunctionName;
            foreach (var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
            {
                dependencies[bound_input] = bound_input;
            }
            break;
        }
        case SavedObject.KindOneofCase.Resource:
            // Resources depend on their `_create_resource` child, when present.
            foreach (var child in proto.Children)
            {
                if (child.LocalName == "_create_resource")
                {
                    dependencies["_create_resource"] = child.NodeId;
                }
            }
            break;
    }
    return dependencies;
}
/// <summary>
/// Loads all nodes and functions from the SavedModel and their edges.
/// The order matters: objects must exist before edges are wired, and
/// save/restore functions are attached last.
/// </summary>
private void _load_all()
{
    _load_nodes();
    _load_edges();
    _setup_remaining_functions();
    _load_checkpoint_save_and_restore_functions();
}
/// <summary>
/// Restores the checkpoint-related save/restore functions to all nodes.
/// Nodes serialized via `_serialize_to_tensors` (single entry keyed by
/// SERIALIZE_TO_TENSORS_NAME) are not supported yet; legacy SaveableObject
/// functions are recreated via `saveable_object_util.recreate_saveable_objects`.
/// </summary>
private void _load_checkpoint_save_and_restore_functions()
{
    foreach(var (node_id, proto) in _iter_all_nodes())
    {
        var node = get(node_id);
        if(node is null)
        {
            // skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
            continue;
        }
        if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
        {
            // Restore Trackable serialize- and restore-from-tensor functions.
            Debug.Assert(proto.SaveableObjects.Count == 1);
            var saveable_object_proto = proto.SaveableObjects.Values.First();
            var save_fn_id = saveable_object_proto.SaveFunction;
            var restore_fn_id = saveable_object_proto.RestoreFunction;
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
        }
        else
        {
            // Restore legacy SaveableObject functions.
            // Maps attribute name -> (save function node, restore function node).
            Dictionary<string, (Trackable, Trackable)> saveable_fn_by_name = new();
            foreach(var item in proto.SaveableObjects)
            {
                var name = item.Key;
                var saveable_object_proto = item.Value;
                var save_fn_id = saveable_object_proto.SaveFunction;
                var restore_fn_id = saveable_object_proto.RestoreFunction;
                saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id));
            }
            node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
        }
    }
}
/// <summary>
/// Load all saved objects.
/// Walks the dependency-ordered node ids, recreating each object unless it was
/// provided by the caller, is a slot variable (unsupported), or is a function
/// (skipped). Populates `_nodes` and `_node_setters`.
/// </summary>
private void _load_nodes()
{
    // `nodes` maps from node ids to recreated objects
    // `node_setters` maps from node ids to setter functions
    // (same signature as setattr) for setting children.
    var (nodes, node_setters) = _initialize_loaded_nodes();
    // Maps slot-variable node id -> (owning optimizer node id, slot proto).
    Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)>
        slot_variable_node_ids = new();
    foreach(var (node_id, proto) in _iter_all_nodes())
    {
        foreach(var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto);
        }
    }
    // Re-create everything.
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        if (nodes.ContainsKey(node_id))
        {
            // Already supplied through `filters` (or added during filtering).
            continue;
        }
        else if (slot_variable_node_ids.ContainsKey(node_id))
        {
            // Use the public Optimizer interface when creating slot variables.
            var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
            var optimizer_object = nodes[optimizer_node_id];
            var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];
            // TODO: implement it.
            throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
                " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        }
        else
        {
            // skip the function and concrete function.
            if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
            {
                nodes[node_id] = null;
                node_setters[node_id] = null;
                continue;
            }
            var (node, setter) = _recreate(proto, node_id, nodes);
            nodes[node_id] = node;
            node_setters[node_id] = setter;
        }
    }
    // Ensure a root object exists even when it was not part of the filters.
    if (!nodes.ContainsKey(0))
    {
        nodes[0] = _recreate_base_user_object().Item1;
    }
    // NOTE(review): this indexes `nodes[i]` for every proto node — confirm all ids
    // are present when a filtered subset was loaded.
    _nodes = new List<Trackable>();
    for(int i = 0; i < _proto.Nodes.Count; i++)
    {
        _nodes.Add(nodes[i]);
    }
    _node_setters = node_setters;
}
/// <summary>
/// Load state from checkpoint into the deserialized objects.
/// Restores either partially (with `allow_partial_checkpoint`) or fully,
/// asserting the corresponding match condition; only eager mode is supported.
/// </summary>
private void _restore_checkpoint()
{
    var variables_path = SavedModelUtils.get_variables_path(_export_dir);
    var saver = new TrackableSaver(new ObjectGraphView(get(0)));
    // NOTE(review): the device scope returned by `tf.device` is not entered/disposed
    // here — confirm this actually pins the restore to CPU.
    tf.device("CPU");
    saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
    LoadStatus load_status;
    if (_save_options.allow_partial_checkpoint)
    {
        load_status = saver.restore(variables_path, _checkpoint_options).expect_partial();
        load_status.assert_nontrivial_match();
    }
    else
    {
        load_status = saver.restore(variables_path, _checkpoint_options);
        load_status.assert_existing_objects_matched();
    }
    // `ckpt` is currently unused; the cast doubles as a runtime type check on load_status.
    var ckpt = (load_status as CheckpointLoadStatus).Checkpoint;
    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("The checkpoint restore has not supported graph mode. " +
            "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}
/// <summary>
/// Adds edges from objects to other objects and functions.
/// When the root is among the filtered nodes the filtered objects would also
/// be attached under the root; that path is not implemented yet.
/// </summary>
private void _load_edges()
{
    foreach(var (node_id, object_proto) in _iter_all_nodes())
    {
        _add_object_graph_edges(object_proto, node_id);
    }
    if(_filtered_nodes is not null && _filtered_nodes.Contains(0))
    {
        var root = get(0);
        foreach(var node_path in _node_filters.Keys)
        {
            var loaded_node = _nodes[_node_path_to_id[node_path]];
            var path = node_path.Split('.');
            var current_node = root;
            // Walk intermediate path segments ("root.a.b" -> a) before attaching.
            foreach(var name in path.Skip(1).Take(path.Length - 2))
            {
                // `hasattr` and `setattr` is used here
                throw new NotImplementedException();
            }
            // `hasattr` and `setattr` is used here
            throw new NotImplementedException();
        }
    }
}
/// <summary>
/// Placeholder: would set up any functions not wired during `_load_edges`.
/// Currently a no-op because concrete-function loading is not implemented.
/// </summary>
private void _setup_remaining_functions()
{
    // TODO: implement it with concrete functions.
}
/// <summary>
/// Returns the recreated object for the given node id (null for skipped functions).
/// </summary>
public Trackable get(int node_id) => _nodes[node_id];
/// <summary>
/// Returns the recreated object for the given node path (e.g. "root.layer").
/// </summary>
public Trackable get(string node_id) => get(_node_path_to_id[node_id]);
/// <summary>
/// Adds edges from an object to its children by invoking the node's setter
/// for each child reference.
/// </summary>
/// <param name="proto">the saved object whose children are attached.</param>
/// <param name="node_id">id of the parent node.</param>
private void _add_object_graph_edges(SavedObject proto, int node_id)
{
    var obj = _nodes[node_id];
    if (obj is null)
    {
        // skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
        // (The null check was previously re-evaluated inside the loop for every child.)
        return;
    }
    var setter = _node_setters[node_id];
    foreach (var refer in proto.Children)
    {
        setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
        // skip the process of "__call__"
    }
}
/// <summary>
/// Splits `_loaded_nodes` into separate id-to-object and id-to-setter maps,
/// which seed the recreation pass in `_load_nodes`.
/// </summary>
private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{
    var nodes = _loaded_nodes.ToDictionary(kv => kv.Key, kv => kv.Value.Item1);
    var node_setters = _loaded_nodes.ToDictionary(kv => kv.Key, kv => kv.Value.Item2);
    return (nodes, node_setters);
}
/// <summary>
/// Yields (node id, proto) pairs in dependency order (lazily, like the
/// original iterator block).
/// </summary>
private IEnumerable<(int, SavedObject)> _iter_all_nodes()
    => _ordered_node_ids.Select(node_id => (node_id, _proto.Nodes[node_id]));
/// <summary>
/// Recreates a single object, resolving its dependency ids to the already
/// recreated objects in `nodes` first.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{
    // skip the registered classes.
    var dependencies = _get_node_dependencies(proto)
        .ToDictionary(kv => kv.Key, kv => nodes[kv.Value]);
    return _recreate_default(proto, node_id, dependencies);
}
/// <summary>
/// Creates an object from a SavedObject protocol buffer, dispatching on its kind.
/// Only user objects and variables are supported so far.
/// </summary>
/// <param name="proto">the saved object to recreate.</param>
/// <param name="node_id">id of the node being recreated.</param>
/// <param name="dependencies">already recreated dependencies, keyed by name or bound-input index.</param>
/// <exception cref="NotImplementedException">for any unsupported object kind.</exception>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies)
{
    return proto.KindCase switch
    {
        SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
        SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
        SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
        SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
        SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
        // Without this arm an unhandled kind (e.g. Asset, Constant, Resource) would
        // surface as an opaque SwitchExpressionException.
        _ => throw new NotImplementedException($"Loading objects of kind {proto.KindCase} is not supported yet. " +
            "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
    };
}
/// <summary>
/// Recreates a user object, preferring a registered revived type and falling
/// back to a plain `_UserObject`.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id)
{
    // skip the check of proto identifier because of lack of property.
    var looked_up = RevivedTypes.deserialize(proto);
    return looked_up is null
        ? _recreate_base_user_object(proto, node_id)
        : (looked_up.Item1, looked_up.Item2);
}
/// <summary>
/// Fallback recreation: a fresh `_UserObject` paired with the reflection-based setter.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null)
    => (new _UserObject(), setattr);
/// <summary>
/// Recreates a variable from its proto as an `UninitializedVariable`
/// (values are filled in later by the checkpoint restore).
/// When device placement is preserved, the saved device is selected first.
/// </summary>
/// <param name="proto">the saved variable description.</param>
private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto)
{
    string name = proto.Name;
    string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>";
    // TODO(Rinne): `validate_synchronization_aggregation_trainable`
    var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable(
        proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name);
    var saved_device = proto.Device;
    var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device);
    // The two branches previously constructed byte-identical variables; only the
    // device selection differed, so it is the only conditional part now.
    if (load_with_device)
    {
        tf.device(saved_device);
    }
    return (new UninitializedVariable(
        shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
        dtype: (TF_DataType)proto.Dtype,
        name: name,
        trainable: trainable,
        aggregation: aggregation
    ), setattr);
}
/// <summary>
/// Would recreate a bare concrete function from its proto; not implemented yet.
/// </summary>
/// <exception cref="NotImplementedException">always.</exception>
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
    Dictionary<Maybe<string, int>, Trackable> dependencies)
{
    throw new NotImplementedException();
    //var fn = function_deserialization.setup_bare_concrete_function(proto, )
}
// TODO: remove this to a common class.
/// <summary>
/// Reflection-based analogue of Python's `setattr`: assigns `z` to the public
/// property of `x` named `y`. When no matching property exists it silently
/// does nothing (see the TODO below).
/// </summary>
public static Action<object, object, object> setattr = (x, y, z) =>
{
    Debug.Assert(y is string);
    // First name match, mirroring the original linear scan over GetProperties().
    var property = x.GetType().GetProperties().FirstOrDefault(p => p.Name == (string)y);
    property?.SetValue(x, z);
    // TODO(Rinne): check if the property has been set successfully.
    //throw new ValueError($"Cannot find the property {y} of {x}.");
};
/// <summary>
/// Plain trackable used as the recreation target when no revived type matches
/// a saved user object.
/// </summary>
public class _UserObject: AutoTrackable
{
}
} | |||||
} |
@@ -0,0 +1,122 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.IO; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Loader | |||||
{ | |||||
/// <summary>
/// Reads and parses the `SavedModel` proto from `export_dir`, preferring the
/// binary `saved_model.pb` over the text `saved_model.pbtxt` (text parsing is
/// not implemented yet).
/// </summary>
/// <param name="export_dir">directory containing the SavedModel.</param>
/// <returns>the parsed SavedModel proto.</returns>
/// <exception cref="IOException">when neither file exists.</exception>
public static SavedModel parse_saved_model(string export_dir)
{
    var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT);
    var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB);
    SavedModel saved_model = new SavedModel();
    if (File.Exists(path_to_pb))
    {
        // `File.ReadAllBytes` is guaranteed to read the whole file; the previous
        // single `Stream.Read` call could legally return fewer bytes than
        // requested, silently truncating the proto.
        // TODO: change to stream mode.
        saved_model.MergeFrom(File.ReadAllBytes(path_to_pb));
        return saved_model;
    }
    else if (File.Exists(path_to_pbtxt))
    {
        throw new NotImplementedException();
    }
    else
    {
        // `Path.DirectorySeparatorChar` is the directory separator; the previously
        // used `Path.PathSeparator` separates PATH-environment entries (';'/':').
        throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.DirectorySeparatorChar}" +
            $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}");
    }
}
// TODO: revise the type of `tags`
/// <summary>
/// Loads the entire object graph of a SavedModel and returns its root object.
/// </summary>
public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null)
    => load_partial(export_dir, null, tags, options)["root"];
/// <summary>
/// Loads a SavedModel, optionally restricted to the node paths in `filters`.
/// Only TF-2.x style SavedModels (single meta graph with an object graph) are
/// supported; `tags` are not supported yet.
/// </summary>
/// <param name="export_dir">directory containing the SavedModel.</param>
/// <param name="filters">node paths mapped to pre-created objects; null loads everything.</param>
/// <param name="tags">unsupported; must be null.</param>
/// <param name="options">load options; defaults are used when null.</param>
/// <returns>loaded objects keyed by filter path, or a single "root" entry when unfiltered.</returns>
public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null)
{
    options ??= new LoadOptions();
    if (tags is not null)
    {
        throw new NotImplementedException();
    }
    var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir);
    Trackable root = null;
    Loader loader = null;
    if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null)
    {
        // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)`
        var meta_graph_def = saved_model_proto.MetaGraphs[0];
        if (!BitConverter.IsLittleEndian)
        {
            // Tensor content is serialized little-endian; swap on big-endian hosts.
            SavedModelUtils.swap_function_tensor_content(meta_graph_def);
        }
        var object_graph_proto = meta_graph_def.ObjectGraphDef;
        var ckpt_options = new CheckpointOptions(options.experimental_io_device);
        tf_with(ops.init_scope(), x =>
        {
            loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
            root = loader.get(0);
            // skip the assignment of `graph_debug_info`.
        });
        // skip the assignment of `tensorflow_version`
        // skip the assignment of `tensorflow_git_version`
        // skip the process of `metrics`.
    }
    else
    {
        if (filters is not null && filters.Count > 0)
        {
            throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                + " version) cannot be loaded with node filters.");
        }
        tf_with(ops.init_scope(), x =>
        {
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        });
    }
    // Null-check style unified to `is not null` (the method previously mixed both forms).
    if (filters is not null && filters.Count > 0)
    {
        return filters.Keys.ToDictionary(x => x, x => loader.get(x));
    }
    else
    {
        return new Dictionary<string, Trackable> { ["root"] = root };
    }
}
/// <summary>
/// Parses the SavedModel proto and (eventually) its debug info; the debug
/// info half is not implemented yet and is always null.
/// </summary>
public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir)
{
    // TODO: implement debug info.
    return (parse_saved_model(export_dir), null);
}
} | |||||
} |
@@ -6,7 +6,6 @@ using System.Text; | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -1,7 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.ModelSaving; | |||||
namespace Tensorflow.Training.Saving.SavedModel | namespace Tensorflow.Training.Saving.SavedModel | ||||
{ | { | ||||
@@ -68,6 +68,34 @@ namespace Tensorflow | |||||
return saveables.ToArray(); | return saveables.ToArray(); | ||||
} | } | ||||
/// <summary>
/// Converts a name-to-tensor mapping into an array of `SaveableObject`s,
/// rejecting tensors that appear under more than one name.
/// </summary>
/// <param name="names_to_saveables">checkpoint names mapped to the tensors to save.</param>
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables)
{
    var saveables = new List<MySaveableObject>();
    var seen_ops = new List<Tensor>();
    // NOTE(review): unlike the BaseResourceVariable overload, entries are not
    // ordered by key here — confirm whether deterministic ordering is required.
    foreach (var (name, op) in enumerate(names_to_saveables))
    {
        foreach (var converted_saveable_object in saveable_objects_for_op(op, name))
            _add_saveable(saveables, seen_ops, converted_saveable_object);
    }
    return saveables.ToArray();
}
/// <summary>
/// Converts a name-to-variable mapping into an array of `SaveableObject`s,
/// processing entries in key-sorted order so checkpoint contents are
/// deterministic, and rejecting variables that appear under more than one name.
/// </summary>
/// <param name="names_to_saveables">checkpoint names mapped to the variables to save.</param>
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables)
{
    var saveables = new List<MySaveableObject>();
    var seen_ops = new List<BaseResourceVariable>();
    foreach(var item in names_to_saveables.OrderBy(x => x.Key))
    {
        foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
        {
            _add_saveable(saveables, seen_ops, converted_saveable_object);
        }
    }
    return saveables.ToArray();
}
private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | ||||
{ | { | ||||
if (seen_ops.Contains(saveable.op)) | if (seen_ops.Contains(saveable.op)) | ||||
@@ -77,6 +105,15 @@ namespace Tensorflow | |||||
seen_ops.Add(saveable.op); | seen_ops.Add(saveable.op); | ||||
} | } | ||||
/// <summary>
/// Appends `saveable` to `saveables`, raising when its underlying variable was
/// already added under another name.
/// </summary>
private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable)
{
    // NOTE(review): the duplicate check inspects `saveable.variable` while the
    // message reads `saveable.op.OriginalVar.Name` — confirm both refer to the
    // same underlying variable.
    if (seen_ops.Contains(saveable.variable))
        throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}");
    saveables.Add(saveable);
    seen_ops.Add(saveable.variable);
}
/// <summary> | /// <summary> | ||||
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | ||||
/// </summary> | /// </summary> | ||||
@@ -136,19 +173,20 @@ namespace Tensorflow | |||||
{ | { | ||||
full_name = name + "_" + attr; | full_name = name + "_" + attr; | ||||
} | } | ||||
if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||||
var op = factory(full_name); | |||||
if(op.TryGet<BaseResourceVariable>(out var variable)) | |||||
{ | { | ||||
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||||
foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||||
{ | { | ||||
yield return op; | |||||
yield return v; | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
var saveable = factory.GetValue<MySaveableObject>(); | |||||
foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||||
var saveable = op.GetValue<MySaveableObject>(); | |||||
foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | |||||
{ | { | ||||
yield return op; | |||||
yield return v; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -214,20 +252,19 @@ namespace Tensorflow | |||||
return names_to_saveables; | return names_to_saveables; | ||||
} | } | ||||
public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||||
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||||
{ | { | ||||
// skip the process of type `PythonState` | // skip the process of type `PythonState` | ||||
if (trackable_has_serialize_to_tensor(obj)) | |||||
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
{ | { | ||||
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; | |||||
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | ||||
var tensor_dict = obj.serialize_to_tensors(); | var tensor_dict = obj.serialize_to_tensors(); | ||||
List<SaveSpec> specs = new(); | List<SaveSpec> specs = new(); | ||||
List<string> local_names = new(); | List<string> local_names = new(); | ||||
string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | ||||
foreach(var pair in tensor_dict) | |||||
foreach (var pair in tensor_dict) | |||||
{ | { | ||||
var tensor_name = pair.Key; | var tensor_name = pair.Key; | ||||
var maybe_tensor = pair.Value; | var maybe_tensor = pair.Value; | ||||
@@ -235,9 +272,9 @@ namespace Tensorflow | |||||
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | ||||
IDictionary<string, Tensor> internal_dict; | IDictionary<string, Tensor> internal_dict; | ||||
if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
{ | { | ||||
internal_dict= new Dictionary<string, Tensor>(); | |||||
internal_dict = new Dictionary<string, Tensor>(); | |||||
internal_dict[""] = tensor; | internal_dict[""] = tensor; | ||||
} | } | ||||
else | else | ||||
@@ -245,13 +282,18 @@ namespace Tensorflow | |||||
internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | ||||
} | } | ||||
foreach(var item in internal_dict) | |||||
foreach (var item in internal_dict) | |||||
{ | { | ||||
specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | ||||
} | } | ||||
} | } | ||||
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new(); | |||||
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||||
return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||||
} | |||||
if (trackable_has_serialize_to_tensor(obj)) | |||||
{ | |||||
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | |||||
return res; | return res; | ||||
} | } | ||||
else | else | ||||
@@ -333,6 +375,28 @@ namespace Tensorflow | |||||
return restored_ops; | return restored_ops; | ||||
}; | }; | ||||
} | } | ||||
/// <summary>
/// Returns a dict of SaveableObject factories generated from loaded fns.
/// Recreating factories from loaded save/restore functions is not supported
/// yet; only the empty case succeeds.
/// </summary>
/// <param name="saveable_fn_by_name">loaded (save, restore) function pairs keyed by attribute name.</param>
/// <param name="temp_session">currently unused; kept for parity with the python signature.</param>
/// <exception cref="NotImplementedException">when any function pairs are supplied.</exception>
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
    IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{
    if (saveable_fn_by_name.Count > 0)
    {
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
    var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
    return res;
}
public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||||
bool call_with_mapped_captures = false) | |||||
{ | |||||
return factory(key); | |||||
} | |||||
} | } | ||||
public class SaveableCompatibilityConverter: Trackable | public class SaveableCompatibilityConverter: Trackable | ||||
@@ -20,8 +20,8 @@ using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
using Tensorflow.Training.Saving.SavedModel; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
@@ -41,9 +41,10 @@ namespace Tensorflow.Train | |||||
protected IDictionary<string, Trackable> _unconditional_dependency_names; | protected IDictionary<string, Trackable> _unconditional_dependency_names; | ||||
protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | ||||
protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | |||||
protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||||
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||||
protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||||
new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||||
private bool _manual_tracking = true; | private bool _manual_tracking = true; | ||||
private static Trackable _none = new AutoTrackable(); | private static Trackable _none = new AutoTrackable(); | ||||
@@ -71,6 +72,18 @@ namespace Tensorflow.Train | |||||
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | ||||
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | ||||
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | ||||
public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | |||||
public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||||
{ | |||||
get | |||||
{ | |||||
return _self_saveable_object_factories; | |||||
} | |||||
set | |||||
{ | |||||
_self_saveable_object_factories = value; | |||||
} | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Restore-on-create for a variable be saved with this `Checkpointable`. | /// Restore-on-create for a variable be saved with this `Checkpointable`. | ||||
@@ -136,9 +149,11 @@ namespace Tensorflow.Train | |||||
_self_update_uid = -1; | _self_update_uid = -1; | ||||
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); | _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | ||||
_unconditional_dependency_names = new Dictionary<string, Trackable>(); | _unconditional_dependency_names = new Dictionary<string, Trackable>(); | ||||
_unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>(); | |||||
} | } | ||||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, | |||||
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
_maybe_initialize_trackable(); | _maybe_initialize_trackable(); | ||||
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | ||||
@@ -174,10 +189,19 @@ namespace Tensorflow.Train | |||||
/// <param name="trackable"></param> | /// <param name="trackable"></param> | ||||
public virtual void _handle_deferred_dependencies(string name, Trackable trackable) | public virtual void _handle_deferred_dependencies(string name, Trackable trackable) | ||||
{ | { | ||||
//_maybe_initialize_trackable(); | |||||
//trackable._maybe_initialize_trackable(); | |||||
// TODO: complete the implementation. | |||||
_maybe_initialize_trackable(); | |||||
trackable._maybe_initialize_trackable(); | |||||
if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) | |||||
{ | |||||
_unconditional_deferred_dependencies.Remove(name); | |||||
foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) | |||||
{ | |||||
checkpoint_position.restore(trackable); | |||||
} | |||||
} | |||||
// TODO(Rinne): deal with `_self_name_based_restores` | |||||
} | } | ||||
public virtual Trackable? _lookup_dependency(string name) | public virtual Trackable? _lookup_dependency(string name) | ||||
@@ -225,12 +249,19 @@ namespace Tensorflow.Train | |||||
return self_tensor_map.Keys.ToList(); | return self_tensor_map.Keys.ToList(); | ||||
} | } | ||||
public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||||
public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
{ | { | ||||
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
{ | |||||
throw new NotImplementedException(); | |||||
//return new TrackableSaveable(this, null, name, null, null); | |||||
} | |||||
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | ||||
{ | { | ||||
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | ||||
throw new NotImplementedException(); | |||||
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
res[""] = create_saveable; | |||||
return res; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -259,4 +290,6 @@ namespace Tensorflow.Train | |||||
} | } | ||||
public record class TrackableReference(string Name, Trackable Refer); | public record class TrackableReference(string Name, Trackable Refer); | ||||
public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); | |||||
} | } |
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
@@ -20,9 +21,9 @@ public static class TrackableUtils | |||||
LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | ||||
} | } | ||||
} | } | ||||
private static string _ESCAPE_CHAR = "."; | |||||
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||||
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||||
internal static string _ESCAPE_CHAR = "."; | |||||
internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||||
internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||||
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | ||||
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | ||||
{ | { | ||||
@@ -5,9 +5,9 @@ using Tensorflow.Variables; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.ModelSaving; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
using Tensorflow.Training.Saving.SavedModel; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -19,7 +19,11 @@ namespace Tensorflow | |||||
protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
protected string _handle_name; | protected string _handle_name; | ||||
protected string handle_name => _handle_name; | |||||
public string handle_name | |||||
{ | |||||
get { return _handle_name; } | |||||
set { _handle_name = value; } | |||||
} | |||||
protected string _unique_id; | protected string _unique_id; | ||||
public string UniqueId => _unique_id; | public string UniqueId => _unique_id; | ||||
@@ -289,10 +293,10 @@ namespace Tensorflow | |||||
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | ||||
} | } | ||||
public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||||
public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
{ | { | ||||
var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||||
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||||
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||||
res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | |||||
return res; | return res; | ||||
} | } | ||||
@@ -238,5 +238,23 @@ namespace Tensorflow | |||||
{ | { | ||||
return _graph_element.eval(session); | return _graph_element.eval(session); | ||||
} | } | ||||
public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable( | |||||
VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name) | |||||
{ | |||||
if(aggregation is null) | |||||
{ | |||||
aggregation = VariableAggregation.None; | |||||
} | |||||
if(synchronization is null) | |||||
{ | |||||
synchronization = VariableSynchronization.Auto; | |||||
} | |||||
if (trainable is null) | |||||
{ | |||||
trainable = synchronization != VariableSynchronization.OnRead; | |||||
} | |||||
return (synchronization.Value, aggregation.Value, trainable.Value); | |||||
} | |||||
} | } | ||||
} | } |
@@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine | |||||
/// </summary> | /// </summary> | ||||
/// <param name="config"></param> | /// <param name="config"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||||
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||||
{ | { | ||||
// Layer instances created during the graph reconstruction process. | // Layer instances created during the graph reconstruction process. | ||||
var created_layers = new Dictionary<string, ILayer>(); | |||||
created_layers = created_layers ?? new Dictionary<string, ILayer>(); | |||||
var node_index_map = new Dictionary<(string, int), int>(); | var node_index_map = new Dictionary<(string, int), int>(); | ||||
var node_count_by_layer = new Dictionary<ILayer, int>(); | var node_count_by_layer = new Dictionary<ILayer, int>(); | ||||
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | ||||
@@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine | |||||
layer = created_layers[layer_name]; | layer = created_layers[layer_name]; | ||||
else | else | ||||
{ | { | ||||
layer = layer_data.ClassName switch | |||||
{ | |||||
"InputLayer" => InputLayer.from_config(layer_data.Config), | |||||
"Dense" => Dense.from_config(layer_data.Config), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); | |||||
created_layers[layer_name] = layer; | created_layers[layer_name] = layer; | ||||
} | } | ||||
@@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine | |||||
Inputs = inputs, | Inputs = inputs, | ||||
Outputs = outputs | Outputs = outputs | ||||
}) | }) | ||||
{ | |||||
Initialize(inputs, outputs, name); | |||||
} | |||||
internal void Initialize(Tensors inputs, Tensors outputs, string name = null) | |||||
{ | { | ||||
_input_layers = new List<ILayer>(); | _input_layers = new List<ILayer>(); | ||||
_output_layers = new List<ILayer>(); | _output_layers = new List<ILayer>(); | ||||
@@ -70,7 +75,14 @@ namespace Tensorflow.Keras.Engine | |||||
this.inputs = inputs; | this.inputs = inputs; | ||||
this.outputs = outputs; | this.outputs = outputs; | ||||
built = true; | built = true; | ||||
_buildInputShape = inputs.shape; | |||||
if(inputs.Length > 0) | |||||
{ | |||||
_buildInputShape = inputs.shape; | |||||
} | |||||
else | |||||
{ | |||||
_buildInputShape = new Saving.TensorShapeConfig(); | |||||
} | |||||
if (outputs.Any(x => x.KerasHistory == null)) | if (outputs.Any(x => x.KerasHistory == null)) | ||||
base_layer_utils.create_keras_history(outputs); | base_layer_utils.create_keras_history(outputs); | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine | |||||
public virtual Shape ComputeOutputShape(Shape input_shape) | public virtual Shape ComputeOutputShape(Shape input_shape) | ||||
=> throw new NotImplementedException(""); | => throw new NotImplementedException(""); | ||||
protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) | |||||
{ | |||||
List<IVariableV1> res = new(); | |||||
var nested_layers = _flatten_layers(false, false); | |||||
foreach (var layer in nested_layers) | |||||
{ | |||||
if (layer is Layer l) | |||||
{ | |||||
if (include_trainable == true && include_non_trainable == true) | |||||
{ | |||||
res.AddRange(l.Variables); | |||||
} | |||||
else if (include_trainable == true && include_non_trainable == false) | |||||
{ | |||||
res.AddRange(l.TrainableVariables); | |||||
} | |||||
else if(include_trainable == false && include_non_trainable == true) | |||||
{ | |||||
res.AddRange(l.NonTrainableVariables); | |||||
} | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
} | } | ||||
} | } |
@@ -12,7 +12,7 @@ public abstract partial class Layer | |||||
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | ||||
public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||||
public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | ||||
{ | { | ||||
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Newtonsoft.Json.Linq; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
@@ -66,16 +67,74 @@ namespace Tensorflow.Keras.Engine | |||||
public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
protected List<IVariableV1> _trainable_weights; | protected List<IVariableV1> _trainable_weights; | ||||
public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||||
public virtual List<IVariableV1> TrainableVariables => TrainableWeights; | |||||
protected List<IVariableV1> _non_trainable_weights; | protected List<IVariableV1> _non_trainable_weights; | ||||
public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | |||||
public List<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||||
public List<IVariableV1> Variables => Weights; | |||||
public virtual List<IVariableV1> TrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
if (!this.Trainable) | |||||
{ | |||||
return new List<IVariableV1>(); | |||||
} | |||||
var children_weights = _gather_children_variables(true); | |||||
return children_weights.Concat(_trainable_weights).Distinct().ToList(); | |||||
} | |||||
} | |||||
public virtual List<IVariableV1> NonTrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
if (!this.Trainable) | |||||
{ | |||||
var children_weights = _gather_children_variables(true, true); | |||||
return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); | |||||
} | |||||
else | |||||
{ | |||||
var children_weights = _gather_children_variables(include_non_trainable: true); | |||||
return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); | |||||
} | |||||
} | |||||
} | |||||
public virtual List<IVariableV1> Weights | |||||
{ | |||||
get | |||||
{ | |||||
return TrainableWeights.Concat(NonTrainableWeights).ToList(); | |||||
} | |||||
set | |||||
{ | |||||
if (Weights.Count() != value.Count()) throw new ValueError( | |||||
$"You called `set_weights` on layer \"{this.name}\"" + | |||||
$"with a weight list of length {len(value)}, but the layer was " + | |||||
$"expecting {len(Weights)} weights."); | |||||
foreach (var (this_w, v_w) in zip(Weights, value)) | |||||
this_w.assign(v_w, read_value: true); | |||||
} | |||||
} | |||||
protected int id; | protected int id; | ||||
public int Id => id; | public int Id => id; | ||||
protected string name; | protected string name; | ||||
protected string base_name; | protected string base_name; | ||||
public string Name => name; | |||||
public string Name | |||||
{ | |||||
get | |||||
{ | |||||
return name; | |||||
} | |||||
set | |||||
{ | |||||
name = value; | |||||
} | |||||
} | |||||
protected bool computePreviousMask; | protected bool computePreviousMask; | ||||
protected List<Operation> updates; | protected List<Operation> updates; | ||||
@@ -85,10 +144,11 @@ namespace Tensorflow.Keras.Engine | |||||
List<INode> inboundNodes; | List<INode> inboundNodes; | ||||
public List<INode> InboundNodes => inboundNodes; | public List<INode> InboundNodes => inboundNodes; | ||||
List<INode> outboundNodes; | List<INode> outboundNodes; | ||||
public List<INode> OutboundNodes => outboundNodes; | public List<INode> OutboundNodes => outboundNodes; | ||||
public JObject SerializedAttributes { get; set; } | |||||
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ||||
public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
public Tensor[] input | public Tensor[] input | ||||
@@ -117,6 +177,11 @@ namespace Tensorflow.Keras.Engine | |||||
protected List<ILayer> _self_tracked_trackables; | protected List<ILayer> _self_tracked_trackables; | ||||
public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
{ | |||||
Initialize(args); | |||||
} | |||||
internal virtual void Initialize(LayerArgs args) | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
// A stateful layer is a layer whose updates are run during inference too, | // A stateful layer is a layer whose updates are run during inference too, | ||||
@@ -273,46 +338,9 @@ namespace Tensorflow.Keras.Engine | |||||
public int count_params() | public int count_params() | ||||
{ | { | ||||
if (Trainable) | if (Trainable) | ||||
return layer_utils.count_params(this, weights); | |||||
return layer_utils.count_params(this, Weights); | |||||
return 0; | return 0; | ||||
} | } | ||||
List<IVariableV1> ILayer.TrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
return _trainable_weights; | |||||
} | |||||
} | |||||
List<IVariableV1> ILayer.NonTrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
return _non_trainable_weights; | |||||
} | |||||
} | |||||
public List<IVariableV1> weights | |||||
{ | |||||
get | |||||
{ | |||||
var weights = new List<IVariableV1>(); | |||||
weights.AddRange(_trainable_weights); | |||||
weights.AddRange(_non_trainable_weights); | |||||
return weights; | |||||
} | |||||
set | |||||
{ | |||||
if (weights.Count() != value.Count()) throw new ValueError( | |||||
$"You called `set_weights` on layer \"{this.name}\"" + | |||||
$"with a weight list of length {len(value)}, but the layer was " + | |||||
$"expecting {len(weights)} weights."); | |||||
foreach (var (this_w, v_w) in zip(weights, value)) | |||||
this_w.assign(v_w, read_value: true); | |||||
} | |||||
} | |||||
public List<IVariableV1> Variables => weights; | |||||
public virtual IKerasConfig get_config() | public virtual IKerasConfig get_config() | ||||
=> args; | => args; | ||||
@@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
using (SharedObjectSavingScope.Enter()) | using (SharedObjectSavingScope.Enter()) | ||||
{ | { | ||||
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine | |||||
IVariableV1 _predict_counter; | IVariableV1 _predict_counter; | ||||
bool _base_model_initialized; | bool _base_model_initialized; | ||||
bool stop_training; | bool stop_training; | ||||
public bool IsGraphNetwork => _is_graph_network; | |||||
public OptimizerV2 Optimizer | public OptimizerV2 Optimizer | ||||
{ | { | ||||
@@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine | |||||
_init_batch_counters(); | _init_batch_counters(); | ||||
} | } | ||||
internal override void Initialize(LayerArgs args) | |||||
{ | |||||
_init_batch_counters(); | |||||
base.Initialize(args); | |||||
} | |||||
void _configure_steps_per_execution(int steps_per_execution) | void _configure_steps_per_execution(int steps_per_execution) | ||||
{ | { | ||||
_steps_per_execution = tf.Variable(steps_per_execution, | _steps_per_execution = tf.Variable(steps_per_execution, | ||||
@@ -81,10 +89,11 @@ namespace Tensorflow.Keras.Engine | |||||
public override List<ILayer> Layers | public override List<ILayer> Layers | ||||
=> _flatten_layers(recursive: false, include_self: false).ToList(); | => _flatten_layers(recursive: false, include_self: false).ToList(); | ||||
public override List<IVariableV1> TrainableVariables | |||||
public override List<IVariableV1> TrainableWeights | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
// skip the assertion of weights created. | |||||
var variables = new List<IVariableV1>(); | var variables = new List<IVariableV1>(); | ||||
if (!Trainable) | if (!Trainable) | ||||
@@ -95,18 +104,40 @@ namespace Tensorflow.Keras.Engine | |||||
foreach (var trackable_obj in _self_tracked_trackables) | foreach (var trackable_obj in _self_tracked_trackables) | ||||
{ | { | ||||
if (trackable_obj.Trainable) | if (trackable_obj.Trainable) | ||||
variables.AddRange(trackable_obj.TrainableVariables); | |||||
variables.AddRange(trackable_obj.TrainableWeights); | |||||
} | } | ||||
foreach (var layer in _self_tracked_trackables) | |||||
variables.AddRange(_trainable_weights); | |||||
return variables.Distinct().ToList(); | |||||
} | |||||
} | |||||
public override List<IVariableV1> NonTrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
// skip the assertion of weights created. | |||||
var variables = new List<IVariableV1>(); | |||||
foreach (var trackable_obj in _self_tracked_trackables) | |||||
{ | { | ||||
if (layer.Trainable) | |||||
variables.AddRange(layer.TrainableVariables); | |||||
variables.AddRange(trackable_obj.NonTrainableWeights); | |||||
} | } | ||||
// variables.AddRange(_trainable_weights); | |||||
if (!Trainable) | |||||
{ | |||||
var trainable_variables = new List<IVariableV1>(); | |||||
foreach (var trackable_obj in _self_tracked_trackables) | |||||
{ | |||||
variables.AddRange(trackable_obj.TrainableWeights); | |||||
} | |||||
variables.AddRange(trainable_variables); | |||||
variables.AddRange(_trainable_weights); | |||||
variables.AddRange(_non_trainable_weights); | |||||
} | |||||
return variables; | |||||
return variables.Distinct().ToList(); | |||||
} | } | ||||
} | } | ||||
@@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine | |||||
: base(args.Inputs, args.Outputs, name: args.Name) | : base(args.Inputs, args.Outputs, name: args.Name) | ||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
if (args.Layers == null) | |||||
args.Layers = new List<ILayer>(); | |||||
// SupportsMasking = true; | // SupportsMasking = true; | ||||
_compute_output_and_mask_jointly = true; | _compute_output_and_mask_jointly = true; | ||||
_auto_track_sub_layers = false; | _auto_track_sub_layers = false; | ||||
@@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine | |||||
_created_nodes = new List<INode>(); | _created_nodes = new List<INode>(); | ||||
// Add to the model any layers passed to the constructor. | // Add to the model any layers passed to the constructor. | ||||
if (args.Layers != null) | |||||
if (args.Layers is not null) | |||||
{ | { | ||||
foreach (var layer in args.Layers) | |||||
add(layer); | |||||
InitLayers(args.Layers); | |||||
} | |||||
} | |||||
public void InitLayers(IEnumerable<ILayer> layers) | |||||
{ | |||||
foreach(var layer in layers) | |||||
{ | |||||
add(layer); | |||||
} | } | ||||
} | } | ||||
@@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||||
{ | { | ||||
throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
} | } | ||||
_buildInputShape = input_shape; | |||||
built = true; | |||||
base.build(input_shape); | |||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
@@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||||
} | } | ||||
public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
{ | { | ||||
_buildInputShape = input_shape; | |||||
built = true; | |||||
base.build(input_shape); | |||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
{ | { | ||||
@@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers { | |||||
if ( alpha < 0f ) { | if ( alpha < 0f ) { | ||||
throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
} | } | ||||
_buildInputShape = input_shape; | |||||
built = true; | |||||
base.build(input_shape); | |||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
@@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers | |||||
return outputs; | return outputs; | ||||
} | } | ||||
public static Dense from_config(LayerArgs args) | |||||
{ | |||||
return new Dense(args as DenseArgs); | |||||
} | |||||
} | } | ||||
} | } |
@@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers | |||||
name: Name); | name: Name); | ||||
} | } | ||||
public static InputLayer from_config(LayerArgs args) | |||||
{ | |||||
return new InputLayer(args as InputLayerArgs); | |||||
} | |||||
public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | ||||
} | } | ||||
} | } |
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics | |||||
public virtual void reset_states() | public virtual void reset_states() | ||||
{ | { | ||||
foreach (var v in weights) | |||||
foreach (var v in Weights) | |||||
v.assign(0); | v.assign(0); | ||||
} | } | ||||
@@ -4,6 +4,7 @@ using System.IO; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
namespace Tensorflow.Keras.Models | namespace Tensorflow.Keras.Models | ||||
@@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models | |||||
public Functional from_config(ModelConfig config) | public Functional from_config(ModelConfig config) | ||||
=> Functional.from_config(config); | => Functional.from_config(config); | ||||
public void load_model(string filepath, bool compile = true) | |||||
public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||||
{ | { | ||||
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); | |||||
var saved_mode = SavedModel.Parser.ParseFrom(bytes); | |||||
var meta_graph_def = saved_mode.MetaGraphs[0]; | |||||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); | |||||
var metadata = SavedMetadata.Parser.ParseFrom(bytes); | |||||
// Recreate layers and metrics using the info stored in the metadata. | |||||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
keras_loader.load_layers(compile: compile); | |||||
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,12 +1,24 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using Newtonsoft.Json.Linq; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.ComponentModel; | |||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Reflection; | |||||
using System.Text.RegularExpressions; | using System.Text.RegularExpressions; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Metrics; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Keras.Utils; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
using static Tensorflow.ApiDef.Types; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving | |||||
{ | { | ||||
public class KerasObjectLoader | public class KerasObjectLoader | ||||
{ | { | ||||
SavedMetadata _metadata; | |||||
SavedObjectGraph _proto; | |||||
Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||||
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); | |||||
List<int> _traversed_nodes_from_config = new List<int>(); | |||||
private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||||
private SavedMetadata _metadata; | |||||
private SavedObjectGraph _proto; | |||||
private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||||
private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>(); | |||||
private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>(); | |||||
private List<int> _traversed_nodes_from_config = new List<int>(); | |||||
private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes; | |||||
private List<int> _models_to_reconstruct; | |||||
public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes; | |||||
static KerasObjectLoader() | |||||
{ | |||||
PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; | |||||
} | |||||
public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | ||||
{ | { | ||||
_metadata = metadata; | _metadata = metadata; | ||||
_proto = object_graph_def; | _proto = object_graph_def; | ||||
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); | _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); | ||||
_models_to_reconstruct = new List<int>(); | |||||
loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -42,15 +66,255 @@ namespace Tensorflow.Keras.Saving | |||||
continue; | continue; | ||||
} | } | ||||
_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||||
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||||
} | |||||
foreach(var node_metadata in metric_list) | |||||
{ | |||||
try | |||||
{ | |||||
if (node_metadata.Identifier.Equals("_tf_keras_metric")) | |||||
{ | |||||
continue; | |||||
} | |||||
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | |||||
node_metadata.Metadata); | |||||
} | |||||
catch(ValueError e) | |||||
{ | |||||
if (compile) | |||||
{ | |||||
throw e; | |||||
} | |||||
// TODO: add logging.warning. | |||||
} | |||||
} | |||||
} | |||||
public string get_path(int node_id) | |||||
{ | |||||
return _node_paths[node_id]; | |||||
} | |||||
/// <summary> | |||||
/// Finish setting up Keras objects. | |||||
/// | |||||
/// This function is executed after all objects and functions have been created. | |||||
/// Call functions and losses are attached to each layer, and once all layers | |||||
/// have been fully set up, graph networks are initialized. | |||||
/// | |||||
/// Subclassed models that are revived from the SavedModel are treated like | |||||
/// layers, and have their call/loss functions attached here. | |||||
/// </summary> | |||||
public void finalize_objects() | |||||
{ | |||||
List<Layer> layers_revived_from_config = new(); | |||||
List<Layer> layers_revived_from_saved_model = new(); | |||||
foreach(var item in loaded_nodes) | |||||
{ | |||||
var node_id = item.Key; | |||||
var node = item.Value.Item1; | |||||
if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id)) | |||||
{ | |||||
continue; | |||||
} | |||||
_unblock_model_reconstruction(node_id, node as Layer); | |||||
if(node is InputLayer or Metric) | |||||
{ | |||||
continue; | |||||
} | |||||
// TODO: deal with `RevivedLayer` and `RevivedInputLayer`. | |||||
layers_revived_from_config.Add(node as Layer); | |||||
} | |||||
_finalize_saved_model_layers(layers_revived_from_saved_model); | |||||
_finalize_config_layers(layers_revived_from_config); | |||||
_reconstruct_all_models(); | |||||
} | |||||
private void _reconstruct_all_models() | |||||
{ | |||||
HashSet<int> all_initialized_models = new(); | |||||
for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--) | |||||
{ | |||||
int model_id = _models_to_reconstruct[i]; | |||||
all_initialized_models.Add(model_id); | |||||
var (model, layers) = model_layer_dependencies[model_id]; | |||||
_reconstruct_model(model_id, model, layers.ToList()); | |||||
_finalize_config_layers(new List<Layer>() { model }); | |||||
} | |||||
Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys)); | |||||
} | |||||
private void _reconstruct_model(int model_id, Model model, List<Layer> layers) | |||||
{ | |||||
var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"]; | |||||
if(model.input is not null && model.input.Length > 0) | |||||
{ | |||||
} | |||||
else if(model is Sequential s) | |||||
{ | |||||
if(layers is null || layers.Count == 0 || layers[0] is not InputLayer) | |||||
{ | |||||
if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer") | |||||
{ | |||||
layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>())); | |||||
} | |||||
else if (config["layers"][0]["config"]["batch_input_shape"] is not null) | |||||
{ | |||||
// TODO(Rinne): implement it | |||||
} | |||||
} | |||||
// `model.__init__(layers, config["name"])` | |||||
s.InitLayers(layers); | |||||
s.Name = config["name"].ToObject<string>(); | |||||
if(s.input is null || s.input.Length == 0) | |||||
{ | |||||
var first_layer = _get_child_layer_node_ids(model_id)[0]; | |||||
var input_specs = _infer_inputs(first_layer); | |||||
var input_shapes = _infer_inputs(first_layer, true); | |||||
// `model._set_inputs(input_specs)` | |||||
// skip the check of input_specs is Dictionary | |||||
if (!s.Built) | |||||
{ | |||||
s.build(input_shapes); | |||||
} | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
// skip the parameter `created_layers`. | |||||
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config), | |||||
layers.ToDictionary(x => x.Name, x => x as ILayer)); | |||||
// skip the `model.__init__` | |||||
(model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>()); | |||||
(model as Functional).connect_ancillary_layers(created_layers); | |||||
} | |||||
_set_network_attributes_from_metadata(model); | |||||
_unblock_model_reconstruction(model_id, model); | |||||
} | |||||
private void _set_network_attributes_from_metadata(Model revived_object) | |||||
{ | |||||
// TODO: implement it. | |||||
} | |||||
/// <summary> | |||||
/// Runs the final steps of loading Keras Layers from config. | |||||
/// </summary> | |||||
/// <param name="layers"></param> | |||||
private void _finalize_config_layers(List<Layer> layers) | |||||
{ | |||||
foreach(var layer in layers) | |||||
{ | |||||
if (_is_graph_network(layer)) | |||||
{ | |||||
_restore_layer_unconditional_losses(layer); | |||||
} | |||||
_restore_layer_activation_loss(layer); | |||||
_restore_layer_metrics(layer); | |||||
// TODO(Rinne): deal with RNN. | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Runs the final steps of loading Keras Layers from SavedModel. | |||||
/// </summary> | |||||
/// <param name="layers"></param> | |||||
private void _finalize_saved_model_layers(List<Layer> layers) | |||||
{ | |||||
foreach(var layer in layers) | |||||
{ | |||||
// TODO(Rinne): deal with `RevivedNetwork`. | |||||
_restore_layer_unconditional_losses(layer); | |||||
_restore_layer_activation_loss(layer); | |||||
_restore_layer_metrics(layer); | |||||
} | |||||
} | |||||
private void _restore_layer_unconditional_losses(Layer layer) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
private void _restore_layer_activation_loss(Layer layer) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
private void _restore_layer_metrics(Layer layer) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
/// <summary> | |||||
/// Removes layer from blocking model reconstruction. | |||||
/// </summary> | |||||
/// <param name="layer_id"></param> | |||||
/// <param name="layer"></param> | |||||
private void _unblock_model_reconstruction(int layer_id, Layer layer) | |||||
{ | |||||
foreach(var depencency in model_layer_ids_dependencies) | |||||
{ | |||||
var layer_ids = depencency.Value.Item2; | |||||
var layers = model_layer_dependencies.SetDefault(depencency.Key, | |||||
(depencency.Value.Item1, new Layer[depencency.Value.Item2.Length])).Item2; | |||||
if (!layer_ids.Contains(layer_id)) | |||||
{ | |||||
continue; | |||||
} | |||||
layers[Array.IndexOf(layer_ids, layer_id)] = layer; | |||||
if (layers.All(x => x is not null)) | |||||
{ | |||||
_models_to_reconstruct.Add(depencency.Key); | |||||
} | |||||
} | } | ||||
} | } | ||||
void _load_layer(int node_id, string identifier, string metadata_json) | |||||
private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | |||||
{ | { | ||||
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||||
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | ||||
_revive_from_config(identifier, metadata, node_id); | |||||
if (loaded_nodes.ContainsKey(node_id)) | |||||
{ | |||||
var (node, setter) = loaded_nodes[node_id]; | |||||
_maybe_add_serialized_attributes(node as Layer, metadata); | |||||
var config = metadata.Config; | |||||
if(_is_graph_network(node as Layer) && generic_utils.validate_config(config)) | |||||
{ | |||||
Debug.Assert(node is Model); | |||||
var child_nodes = _get_child_layer_node_ids(node_id); | |||||
model_layer_ids_dependencies[node_id] = (node as Model, child_nodes); | |||||
if(child_nodes is null || child_nodes.Length == 0) | |||||
{ | |||||
_models_to_reconstruct.Add(node_id); | |||||
} | |||||
} | |||||
return (node, setter); | |||||
} | |||||
else | |||||
{ | |||||
var (obj, setter) = _revive_from_config(identifier, metadata, node_id); | |||||
if (obj is null) | |||||
{ | |||||
(obj, setter) = _revive_custom_object(identifier, metadata); | |||||
} | |||||
Debug.Assert(obj is Layer); | |||||
_maybe_add_serialized_attributes(obj as Layer, metadata); | |||||
return (obj, setter); | |||||
} | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -59,11 +323,34 @@ namespace Tensorflow.Keras.Saving | |||||
/// <param name="identifier"></param> | /// <param name="identifier"></param> | ||||
/// <param name="metadata"></param> | /// <param name="metadata"></param> | ||||
/// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||||
private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||||
{ | { | ||||
var obj = _revive_graph_network(identifier, metadata, node_id); | |||||
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||||
Trackable obj; | |||||
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
return (null, null); | |||||
//throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
} | |||||
else | |||||
{ | |||||
obj = _revive_graph_network(identifier, metadata, node_id); | |||||
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||||
} | |||||
if(obj is null) | |||||
{ | |||||
return (null, null); | |||||
} | |||||
var setter = _config_node_setter(_revive_setter); | |||||
_add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); | _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); | ||||
return (obj, setter); | |||||
} | |||||
private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
throw new NotImplementedException(); | |||||
} | } | ||||
Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | ||||
@@ -71,6 +358,12 @@ namespace Tensorflow.Keras.Saving | |||||
var config = metadata.Config; | var config = metadata.Config; | ||||
var class_name = metadata.ClassName; | var class_name = metadata.ClassName; | ||||
Model model = null; | Model model = null; | ||||
if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") | |||||
{ | |||||
return null; | |||||
} | |||||
if (class_name == "Sequential") | if (class_name == "Sequential") | ||||
{ | { | ||||
model = new Sequential(new SequentialArgs | model = new Sequential(new SequentialArgs | ||||
@@ -78,34 +371,82 @@ namespace Tensorflow.Keras.Saving | |||||
Name = config.GetValue("name").ToString() | Name = config.GetValue("name").ToString() | ||||
}); | }); | ||||
} | } | ||||
else if (class_name == "Functional") | |||||
else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | |||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
model = new Sequential(new SequentialArgs | |||||
{ | |||||
Name = class_name | |||||
}); | |||||
} | |||||
else | |||||
{ | |||||
model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>()); | |||||
} | } | ||||
if (!metadata.IsGraphNetwork) | |||||
return null; | |||||
// Record this model and its layers. This will later be used to reconstruct | // Record this model and its layers. This will later be used to reconstruct | ||||
// the model. | // the model. | ||||
var layers = _get_child_layer_node_ids(node_id); | var layers = _get_child_layer_node_ids(node_id); | ||||
model_layer_dependencies[node_id] = (model, layers); | |||||
model_layer_ids_dependencies[node_id] = (model, layers); | |||||
if(layers is null || layers.Length == 0) | |||||
{ | |||||
_models_to_reconstruct.Add(node_id); | |||||
} | |||||
return model; | return model; | ||||
} | } | ||||
Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||||
Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||||
{ | { | ||||
var config = metadata.Config; | var config = metadata.Config; | ||||
var class_name = metadata.ClassName; | var class_name = metadata.ClassName; | ||||
var shared_object_id = metadata.SharedObjectId; | var shared_object_id = metadata.SharedObjectId; | ||||
var must_restore_from_config = metadata.MustRestoreFromConfig; | var must_restore_from_config = metadata.MustRestoreFromConfig; | ||||
var obj = class_name switch | |||||
{ | |||||
"Resizing" => Resizing.from_config(config), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
var obj = generic_utils.deserialize_keras_object(class_name, config); | |||||
obj.Name = metadata.Name; | |||||
// TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | |||||
var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | ||||
return null; | |||||
if (!built) | |||||
{ | |||||
return null; | |||||
} | |||||
return obj; | |||||
} | |||||
private void _revive_setter(object layer, object name, object value) | |||||
{ | |||||
Debug.Assert(name is string); | |||||
Debug.Assert(layer is Layer); | |||||
if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||||
{ | |||||
if(value is Trackable) | |||||
{ | |||||
(layer as Layer)._track_trackable(value as Trackable, name as string); | |||||
} | |||||
if((layer as Layer).SerializedAttributes is null) | |||||
{ | |||||
(layer as Layer).SerializedAttributes = new JObject(); | |||||
} | |||||
(layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); | |||||
} | |||||
else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||||
{ | |||||
(layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); | |||||
} | |||||
else | |||||
{ | |||||
var properties = layer.GetType().GetProperties(); | |||||
foreach(var p in properties) | |||||
{ | |||||
if(p.Name == name as string && p.GetValue(layer) is not null) | |||||
{ | |||||
return; | |||||
} | |||||
} | |||||
Loader.setattr(layer, name, value); | |||||
} | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -143,34 +484,186 @@ namespace Tensorflow.Keras.Saving | |||||
/// <param name="obj"></param> | /// <param name="obj"></param> | ||||
/// <param name="proto"></param> | /// <param name="proto"></param> | ||||
/// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) | |||||
void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) | |||||
{ | { | ||||
if (_traversed_nodes_from_config.Contains(node_id)) | if (_traversed_nodes_from_config.Contains(node_id)) | ||||
return; | return; | ||||
var parent_path = _node_paths[node_id]; | var parent_path = _node_paths[node_id]; | ||||
_traversed_nodes_from_config.Add(node_id); | _traversed_nodes_from_config.Add(node_id); | ||||
if (!obj.Built) | |||||
obj._maybe_initialize_trackable(); | |||||
if(obj is Layer layer && !layer.Built) | |||||
{ | { | ||||
var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||||
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||||
_try_build_layer(obj, node_id, metadata.BuildInputShape); | |||||
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata); | |||||
_try_build_layer(layer, node_id, metadata.BuildInputShape); | |||||
} | |||||
List<(Trackable, int, string)> children = new(); | |||||
foreach(var refer in proto.Children) | |||||
{ | |||||
var obj_child = obj._lookup_dependency(refer.LocalName); | |||||
children.Add((obj_child, refer.NodeId, refer.LocalName)); | |||||
} | |||||
var metric_list_node_id = _search_for_child_node(node_id, new string[] { | |||||
Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||||
}); | |||||
if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | |||||
{ | |||||
var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); | |||||
foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) | |||||
{ | |||||
if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) | |||||
{ | |||||
var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; | |||||
children.Add((metric as Metric, refer.NodeId, metric_path)); | |||||
} | |||||
} | |||||
} | |||||
foreach(var (obj_child, child_id, child_name) in children) | |||||
{ | |||||
if(obj_child is null) | |||||
{ | |||||
continue; | |||||
} | |||||
var child_proto = _proto.Nodes[child_id]; | |||||
// skip the check for registered identifier | |||||
Action<object, object, object> setter; | |||||
if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||||
{ | |||||
setter = _revive_setter; | |||||
} | |||||
else | |||||
{ | |||||
setter = Loader.setattr; | |||||
} | |||||
if (loaded_nodes.ContainsKey(child_id)) | |||||
{ | |||||
// skip the logging.warning | |||||
continue; | |||||
} | |||||
if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) | |||||
{ | |||||
(obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; | |||||
} | |||||
if(obj_child is TrackableDataStructure) | |||||
{ | |||||
setter = (x, y, z) => { }; | |||||
} | |||||
var child_path = $"{parent_path}.{child_name}"; | |||||
_node_paths[child_id] = child_path; | |||||
_add_children_recreated_from_config(obj_child, child_proto, child_id); | |||||
loaded_nodes[child_id] = (obj_child, setter); | |||||
} | } | ||||
} | } | ||||
bool _try_build_layer(Model obj, int node_id, Shape build_input_shape) | |||||
private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||||
{ | { | ||||
if (obj.Built) | if (obj.Built) | ||||
return true; | return true; | ||||
if(build_input_shape is null) | |||||
{ | |||||
build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||||
} | |||||
if(build_input_shape is not null) | |||||
{ | |||||
obj.build(build_input_shape); | |||||
// In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. | |||||
// On the one hand, C# does not support call a method from specified parent class. | |||||
// On the other hand, currently All class derived from Layer call `Layer.Build` or | |||||
// move the implementation of `Layer.build` to its own `build` method. | |||||
// Therefore we do not call it here. | |||||
// However, it's still quite risky once in the future a certain class derived from | |||||
// `Layer` does not call `Layer.build`. | |||||
return true; | |||||
} | |||||
return false; | return false; | ||||
} | } | ||||
bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||||
/// <summary> | |||||
/// Infers input shape of layer from SavedModel functions. | |||||
/// </summary> | |||||
/// <param name="layer_node_id"></param> | |||||
/// <param name="convert_to_shapes"></param> | |||||
/// <returns></returns> | |||||
private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||||
{ | { | ||||
if (obj.Built) | |||||
return true; | |||||
var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | |||||
if(call_fn_id is null) | |||||
{ | |||||
return null; | |||||
} | |||||
var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; | |||||
if(concrete_functions is null) | |||||
{ | |||||
return null; | |||||
} | |||||
var call_fn_name = concrete_functions[0]; | |||||
var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | |||||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
} | |||||
private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||||
{ | |||||
if(path_to_child is null || path_to_child.Count() == 0) | |||||
{ | |||||
return parent_id; | |||||
} | |||||
foreach(var child in _proto.Nodes[parent_id].Children) | |||||
{ | |||||
if(child.LocalName == path_to_child.First()) | |||||
{ | |||||
return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); | |||||
} | |||||
} | |||||
return null; | |||||
} | |||||
private bool _is_graph_network(Layer layer) | |||||
{ | |||||
// TODO: deal with `RevivedLayer` | |||||
if(layer is Functional) | |||||
{ | |||||
return (layer as Functional).IsGraphNetwork || layer is Sequential; | |||||
} | |||||
return false; | return false; | ||||
} | } | ||||
private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) | |||||
{ | |||||
// TODO: deal with `RevivedLayer` | |||||
} | |||||
/// <summary> | |||||
/// Creates edges for nodes that are recreated from config. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
private Action<object, object, object> _config_node_setter(Action<object, object, object> setter) | |||||
{ | |||||
void setattr_wrapper(object obj, object name, object value) | |||||
{ | |||||
Debug.Assert(obj is Trackable); | |||||
Debug.Assert(name is string); | |||||
if((obj as Trackable)._lookup_dependency(name as string) is null) | |||||
{ | |||||
setter(obj, name, value); | |||||
} | |||||
} | |||||
return setattr_wrapper; | |||||
} | |||||
} | } | ||||
} | } |
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||||
public partial class KerasSavedModelUtils | public partial class KerasSavedModelUtils | ||||
{ | { | ||||
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||||
public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||||
SaveOptions? options, bool save_traces = true) | SaveOptions? options, bool save_traces = true) | ||||
{ | { | ||||
if (!overwrite && File.Exists(filepath)) | if (!overwrite && File.Exists(filepath)) | ||||
@@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils | |||||
BadConsumers = { } | BadConsumers = { } | ||||
}, | }, | ||||
Identifier = layer.ObjectIdentifier, | Identifier = layer.ObjectIdentifier, | ||||
Metadata = layer.TrackingMetadata | |||||
Metadata = layer.GetTrackingMetadata() | |||||
}; | }; | ||||
metadata.Nodes.Add(saved_object); | metadata.Nodes.Add(saved_object); | ||||
@@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils | |||||
if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
})); | })); | ||||
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||||
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => | |||||
{ | { | ||||
if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
@@ -0,0 +1,96 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Train; | |||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | |||||
{ | |||||
public class KerasLoadModelUtils | |||||
{ | |||||
/// <summary> | |||||
/// Corresponding to keras/saving/save.py/load_model | |||||
/// </summary> | |||||
/// <param name="filepath"></param> | |||||
/// <param name="custom_objects"></param> | |||||
/// <param name="compile"></param> | |||||
/// <param name="options"></param> | |||||
/// <returns></returns> | |||||
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||||
bool compile = true, LoadOptions? options = null) | |||||
{ | |||||
using (SharedObjectSavingScope.Enter()) | |||||
{ | |||||
using (LoadContext.load_context(options)) | |||||
{ | |||||
if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||||
{ | |||||
throw new IOException($"No file or directory found at {filepath}."); | |||||
} | |||||
if (Directory.Exists(filepath)) | |||||
{ | |||||
return load(filepath, compile, options); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||||
{ | |||||
SavedMetadata metadata = new SavedMetadata(); | |||||
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||||
if (File.Exists(path_to_metadata_pb)) | |||||
{ | |||||
metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
} | |||||
if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||||
{ | |||||
return Loader.load(path, options: options) as Model; | |||||
} | |||||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
keras_loader.load_layers(compile: compile); | |||||
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||||
nodes_to_load["root"] = (null, null); | |||||
foreach(var item in keras_loader.LoadedNodes) | |||||
{ | |||||
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||||
} | |||||
var loaded = Loader.load_partial(path, nodes_to_load, options); | |||||
keras_loader.finalize_objects(); | |||||
// keras_loader.del_tracking(); | |||||
var model = loaded["root"]; | |||||
if(model is Model && compile) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
if (!tf.Context.executing_eagerly()) | |||||
{ | |||||
// TODO(Rinne): implement it. | |||||
} | |||||
return model; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,69 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using System.Threading; | |||||
using Tensorflow.Training.Saving.SavedModel; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | |||||
{ | |||||
// TODO: remove this class to common project. | |||||
public class ContextHandler: IDisposable | |||||
{ | |||||
public Action<bool> DisposeCallBack { get; set; } | |||||
public void Dispose() | |||||
{ | |||||
DisposeCallBack.Invoke(true); | |||||
} | |||||
} | |||||
/// <summary>
/// Thread-local context tracking the <see cref="LoadOptions"/> in effect while a
/// SavedModel is being loaded. Mirrors TensorFlow Python's `load_context`.
/// </summary>
public class LoadContext
{
    private bool _entered_load_context;
    private LoadOptions? _load_options;
    // The value factory guarantees Value is never null, so the static accessors
    // below are safe even on a thread that has not entered a load context yet
    // (the original code threw NullReferenceException in that case).
    private static ThreadLocal<LoadContext> _load_context = new(() => new LoadContext());

    private LoadContext()
    {
        _entered_load_context = false;
        _load_options = null;
    }

    // Accepts null to match load_context(LoadOptions?), which forwards its
    // possibly-null argument here.
    public void set_load_options(LoadOptions? load_options)
    {
        _load_options = load_options;
        _entered_load_context = true;
    }

    private void clear_load_options()
    {
        _load_options = null;
        _entered_load_context = false;
    }

    private LoadOptions? load_options()
    {
        return _load_options;
    }

    /// <summary>
    /// Enters a load context on the current thread. Dispose the returned handler
    /// to leave it and clear the stored options.
    /// </summary>
    public static ContextHandler load_context(LoadOptions? load_options)
    {
        _load_context.Value.set_load_options(load_options);
        return new ContextHandler()
        {
            DisposeCallBack = _ => _load_context.Value.clear_load_options()
        };
    }

    /// <summary>
    /// The options of the innermost active load context on this thread, or null.
    /// </summary>
    public static LoadOptions? get_load_option()
    {
        return _load_context.Value.load_options();
    }

    /// <summary>
    /// Whether the current thread is inside a load context.
    /// </summary>
    public static bool in_load_context()
    {
        return _load_context.Value._entered_load_context;
    }
}
} |
@@ -19,15 +19,21 @@ using Newtonsoft.Json.Linq; | |||||
using System; | using System; | ||||
using System.Collections; | using System.Collections; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Data; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Reflection; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
{ | { | ||||
public class generic_utils | public class generic_utils | ||||
{ | { | ||||
private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; | |||||
/// <summary> | /// <summary> | ||||
/// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | ||||
/// </summary> | /// </summary> | ||||
@@ -51,6 +57,58 @@ namespace Tensorflow.Keras.Utils | |||||
return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | ||||
} | } | ||||
/// <summary>
/// Deserializes a layer from its serialized keras config: resolves the matching
/// `{class_name}Args` type, binds the JSON config to it, and instantiates the layer.
/// </summary>
/// <param name="class_name">Keras class name, e.g. "Dense".</param>
/// <param name="config">JSON node holding the layer's constructor arguments.</param>
/// <returns>The constructed layer.</returns>
public static Layer deserialize_keras_object(string class_name, JToken config)
{
    // Delegate to the two existing helpers instead of duplicating their
    // reflection logic: first materialize the args, then build the layer.
    var args = deserialize_layer_args(class_name, config);
    return deserialize_keras_object(class_name, args);
}
/// <summary>
/// Instantiates `Tensorflow.Keras.Layers.{class_name}` via reflection, passing the
/// given args to its constructor.
/// </summary>
/// <param name="class_name">Keras class name, e.g. "Dense".</param>
/// <param name="args">Constructor arguments for the layer.</param>
/// <returns>The constructed layer.</returns>
public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{
    var typeName = $"Tensorflow.Keras.Layers.{class_name}";
    var ctorArgs = new object[] { args };
    var instance = Assembly.Load("Tensorflow.Keras")
        .CreateInstance(typeName, true, BindingFlags.Default, null, ctorArgs, null, null);
    Debug.Assert(instance is Layer);
    return instance as Layer;
}
/// <summary>
/// Binds a serialized layer config to the corresponding
/// `Tensorflow.Keras.ArgsDefinition.{class_name}Args` type.
/// </summary>
/// <param name="class_name">Keras class name, e.g. "Dense".</param>
/// <param name="config">JSON node holding the layer's constructor arguments.</param>
/// <returns>The deserialized args, as <see cref="LayerArgs"/>.</returns>
/// <exception cref="InvalidOperationException">No args type exists for the class name.</exception>
public static LayerArgs deserialize_layer_args(string class_name, JToken config)
{
    var argType = Assembly.Load("Tensorflow.Binding")
        .GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
    if (argType is null)
        throw new InvalidOperationException(
            $"Cannot find type `Tensorflow.Keras.ArgsDefinition.{class_name}Args` for layer `{class_name}`.");
    // The non-generic JToken.ToObject(Type) replaces the original hand-rolled
    // reflection over the generic ToObject<T>() — same deserialization, no
    // MethodInfo lookup and MakeGenericMethod per call.
    var args = config.ToObject(argType);
    Debug.Assert(args is LayerArgs);
    return args as LayerArgs;
}
/// <summary>
/// Builds a <see cref="ModelConfig"/> from a functional model's JSON config:
/// name, per-layer configs (with their args and inbound nodes), and the
/// input/output layer references.
/// </summary>
/// <param name="json">The model's "config" JSON node.</param>
/// <returns>The deserialized model configuration.</returns>
public static ModelConfig deserialize_model_config(JToken json)
{
    var layers = new List<LayerConfig>();
    foreach (var layerToken in json["layers"])
    {
        var className = layerToken["class_name"].ToObject<string>();
        layers.Add(new LayerConfig()
        {
            Config = deserialize_layer_args(className, layerToken["config"]),
            Name = layerToken["name"].ToObject<string>(),
            ClassName = className,
            InboundNodes = layerToken["inbound_nodes"].ToObject<List<NodeConfig>>()
        });
    }
    return new ModelConfig()
    {
        Name = json["name"].ToObject<string>(),
        Layers = layers,
        InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(),
        OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>()
    };
}
public static string to_snake_case(string name) | public static string to_snake_case(string name) | ||||
{ | { | ||||
return string.Concat(name.Select((x, i) => | return string.Concat(name.Select((x, i) => | ||||
@@ -60,5 +118,15 @@ namespace Tensorflow.Keras.Utils | |||||
x.ToString(); | x.ToString(); | ||||
})).ToLower(); | })).ToLower(); | ||||
} | } | ||||
/// <summary>
/// Determines whether config appears to be a valid layer config, i.e. the layer
/// was not saved without a config.
/// </summary>
/// <param name="config">Deserialized keras layer metadata.</param>
/// <returns><c>true</c> when the config is usable for reconstruction.</returns>
public static bool validate_config(JObject config)
    => !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY);
} | } | ||||
} | } |
@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils | |||||
} | } | ||||
var trainable_count = count_params(model, model.TrainableVariables); | var trainable_count = count_params(model, model.TrainableVariables); | ||||
var non_trainable_count = count_params(model, model.non_trainable_variables); | |||||
var non_trainable_count = count_params(model, model.NonTrainableVariables); | |||||
print($"Total params: {trainable_count + non_trainable_count}"); | print($"Total params: {trainable_count + non_trainable_count}"); | ||||
print($"Trainable params: {trainable_count}"); | print($"Trainable params: {trainable_count}"); | ||||
@@ -0,0 +1,9 @@ | |||||
´$root"_tf_keras_network*’${"name": "model", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": []}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]]}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}, "shared_object_id": 9, "input_spec": 
[{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "ndim": 4, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": [], "shared_object_id": 0}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, 
"kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}}}2 | |||||
†root.layer-0"_tf_keras_input_layer*Ö{"class_name": "InputLayer", "name": "input_1", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}2 | |||||
Íroot.layer-1"_tf_keras_layer*£{"name": "flatten", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 1, "axes": {}}, "shared_object_id": 14}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}}2 | |||||
¯root.layer_with_weights-0"_tf_keras_layer*ø{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 784}}, "shared_object_id": 15}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 | |||||
²root.layer_with_weights-1"_tf_keras_layer*û{"name": "dense_1", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 100}}, "shared_object_id": 16}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 100]}}2 | |||||
Šroot.layer-4"_tf_keras_layer*à{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 | |||||
¹Troot.keras_api.metrics.0"_tf_keras_metric*‚{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 17}2 | |||||
™Uroot.keras_api.metrics.1"_tf_keras_metric*â{"class_name": "MeanMetricWrapper", "name": "sparse_categorical_accuracy", "dtype": "float32", "config": {"name": "sparse_categorical_accuracy", "dtype": "float32", "fn": "sparse_categorical_accuracy"}, "shared_object_id": 18}2 |
@@ -0,0 +1,68 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Metrics; | |||||
using Tensorflow; | |||||
using Tensorflow.Keras.Optimizers; | |||||
using static Tensorflow.KerasApi; | |||||
using Tensorflow.NumPy; | |||||
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; | |||||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||||
[TestClass]
public class SequentialModelLoad
{
    /// <summary>
    /// Loads a python-saved model, checks restored weights against reference
    /// arrays, then runs a short training round to ensure the graph still trains.
    /// </summary>
    [TestMethod]
    public void SimpleModelFromAutoCompile()
    {
        var model = keras.models.load_model(@"Assets/simple_model_from_auto_compile");
        model.summary();

        model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        // Compare the restored weights element-wise with arrays dumped from python.
        var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy");
        var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy");
        Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second));
        Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second));

        var loader = new MnistModelLoader();
        var dataset = loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 50000,
        }).Result;

        var batchSize = 8;
        var epochCount = 1;
        model.fit(dataset.Train.Data, dataset.Train.Labels, batchSize, epochCount);
    }

    /// <summary>
    /// Round-trips an Alexnet: saves it via the companion save test, reloads it,
    /// and verifies the loaded model can be compiled and fitted on random data.
    /// </summary>
    [TestMethod]
    public void AlexnetFromSequential()
    {
        // Produce the SavedModel on disk first, then load it back.
        new SequentialModelSave().AlexnetFromSequential();
        var model = keras.models.load_model(@"./alexnet_from_sequential");
        model.summary();

        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

        var batchSize = 8;
        var epochCount = 1;
        var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
        model.fit(dataset.Data, dataset.Labels, batchSize, epochCount);
    }
}
@@ -1,27 +1,21 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow.NumPy; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using System.Diagnostics; | |||||
using Tensorflow; | using Tensorflow; | ||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Metrics; | |||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.Operations; | |||||
using System.Diagnostics; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | namespace TensorFlowNET.Keras.UnitTest.SaveModel; | ||||
[TestClass] | [TestClass] | ||||
public class SequentialModelTest | |||||
public class SequentialModelSave | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void SimpleModelFromAutoCompile() | public void SimpleModelFromAutoCompile() | ||||
@@ -63,6 +57,8 @@ public class SequentialModelTest | |||||
keras.layers.Softmax(1) | keras.layers.Softmax(1) | ||||
}); | }); | ||||
model.summary(); | |||||
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | ||||
var data_loader = new MnistModelLoader(); | var data_loader = new MnistModelLoader(); | ||||
@@ -82,7 +78,7 @@ public class SequentialModelTest | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void AlexModelFromSequential() | |||||
public void AlexnetFromSequential() | |||||
{ | { | ||||
Model model = KerasApi.keras.Sequential(new List<ILayer>() | Model model = KerasApi.keras.Sequential(new List<ILayer>() | ||||
{ | { | ||||
@@ -116,7 +112,7 @@ public class SequentialModelTest | |||||
keras.layers.Softmax(1) | keras.layers.Softmax(1) | ||||
}); | }); | ||||
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | |||||
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||||
var num_epochs = 1; | var num_epochs = 1; | ||||
var batch_size = 8; | var batch_size = 8; | ||||
@@ -125,7 +121,7 @@ public class SequentialModelTest | |||||
model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | ||||
model.save("./pb_alex_sequential", save_format: "tf"); | |||||
model.save("./alexnet_from_sequential", save_format: "tf"); | |||||
// The saved model can be test with the following python code: | // The saved model can be test with the following python code: | ||||
#region alexnet_python_code | #region alexnet_python_code |
@@ -27,4 +27,28 @@ | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<None Update="Assets\simple_model_from_auto_compile\fingerprint.pb"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Assets\simple_model_from_auto_compile\keras_metadata.pb"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Assets\simple_model_from_auto_compile\saved_model.pb"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Assets\simple_model_from_auto_compile\variables\variables.data-00000-of-00001"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Assets\simple_model_from_auto_compile\variables\variables.index"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Assets\simple_model_from_auto_compile\kernel1.npy"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Assets\simple_model_from_auto_compile\bias0.npy"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
</ItemGroup> | |||||
</Project> | </Project> |