@@ -1,5 +1,5 @@ | |||
namespace Tensorflow.Checkpoint; | |||
public record class CheckpointOptions( | |||
string experimental_io_device = null, | |||
string? experimental_io_device = null, | |||
bool experimental_enable_async_checkpoint = false); |
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Serilog.Debugging; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Train; | |||
namespace Tensorflow.Checkpoint; | |||
@@ -21,9 +22,9 @@ public class ObjectGraphView: TrackableView, ICloneable | |||
return new ObjectGraphView(Root, _attached_dependencies); | |||
} | |||
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null) | |||
{ | |||
List<TrackableReference> res = base.children(obj, save_type) | |||
List<TrackableReference> res = base.children(obj, save_type, serialization_cache) | |||
.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | |||
// Check the reference, not value. | |||
if (obj == Root && _attached_dependencies is not null) | |||
@@ -34,9 +35,9 @@ public class ObjectGraphView: TrackableView, ICloneable | |||
return res; | |||
} | |||
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null) | |||
{ | |||
return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); | |||
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer); | |||
} | |||
public IEnumerable<TrackableReference>? AttachedDependencies | |||
@@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint | |||
); | |||
public static class SaveUtil | |||
{ | |||
public static (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
{ | |||
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
@@ -117,16 +117,16 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="call_with_mapped_captures"></param> | |||
/// <param name="cache"></param> | |||
/// <param name="object_graph_proto"></param> | |||
private static IDictionary<Trackable, IDictionary<string, object>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||
{ | |||
Dictionary<Trackable, IDictionary<string, object>> serialized_tensors = new(); | |||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
foreach(var td in tensor_trackables) | |||
{ | |||
// TODO: deal with cache. | |||
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
var trackable = td.object_to_save; | |||
IDictionary<string, object> tensor_dict; | |||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
{ | |||
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||
@@ -147,12 +147,12 @@ namespace Tensorflow.Checkpoint | |||
return serialized_tensors; | |||
} | |||
private static IDictionary<string, object> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
{ | |||
var trackable = trackable_data.object_to_save; | |||
// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | |||
IDictionary<string, object> ret_tensor_dict; | |||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
if (call_with_mapped_captures) | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -162,8 +162,8 @@ namespace Tensorflow.Checkpoint | |||
ret_tensor_dict = trackable.serialize_to_tensors(); | |||
} | |||
// TODO: revise the types and complete it | |||
Dictionary<string, object> tensor_dict = new(); | |||
// TODO: deal with the type `SaveSpce` (currently it will never be it). | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
foreach(var pair in ret_tensor_dict) | |||
{ | |||
var local_name = TrackableUtils.escape_local_name(pair.Key); | |||
@@ -172,9 +172,10 @@ namespace Tensorflow.Checkpoint | |||
tensor_dict[checkpoint_key] = maybe_tensor; | |||
if(maybe_tensor is SaveSpec) | |||
if(maybe_tensor.GetValueA() is SaveSpec) | |||
{ | |||
((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
throw new NotImplementedException(); | |||
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
} | |||
if(object_graph_proto is not null) | |||
@@ -198,7 +199,7 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="call_with_mapped_captures"></param> | |||
/// <param name="object_graph_proto"></param> | |||
/// <returns></returns> | |||
private static (Trackable, IDictionary<string, object>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
{ | |||
Dictionary<Trackable, string> object_names = new(); | |||
@@ -174,25 +174,20 @@ public static class SaveUtilV1 | |||
{ | |||
var name = factory_data.name; | |||
var key = factory_data.checkpoint_key; | |||
var saveable_factory = factory_data.factory; | |||
var maybe_saveable = factory_data.factory; | |||
// TODO: oneflow python has a process with callable `saveable_factory`. | |||
var maybe_saveable = saveable_factory; | |||
IEnumerable<MySaveableObject> savesbles; | |||
if (maybe_saveable is MySaveableObject) | |||
{ | |||
savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable }; | |||
} | |||
else if (maybe_saveable is Tensor) | |||
List<MySaveableObject> saveables = new(); | |||
if (maybe_saveable.DataType == typeof(MySaveableObject)) | |||
{ | |||
savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); | |||
saveables.Add(maybe_saveable.GetValueB()); | |||
} | |||
else | |||
{ | |||
throw new TypeError("Unexpected type."); | |||
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); | |||
} | |||
foreach (var saveable in savesbles) | |||
foreach (var saveable in saveables) | |||
{ | |||
if (!saveable.name.Contains(key)) | |||
{ | |||
@@ -204,11 +199,11 @@ public static class SaveUtilV1 | |||
// skip the process of PythonState | |||
named_saveable_objects.AddRange(savesbles); | |||
named_saveable_objects.AddRange(saveables); | |||
if(!fill_object_proto) continue; | |||
// skip the process of TrackableSaveable | |||
// skip the process of `TrackableSaveable` because of lack of APIs. | |||
object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() | |||
{ Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); | |||
@@ -221,7 +216,7 @@ public static class SaveUtilV1 | |||
public record class CheckpointFactoryData | |||
( | |||
object factory, | |||
Maybe<ResourceVariable, MySaveableObject> factory, | |||
string name, | |||
string checkpoint_key | |||
); |
@@ -2,6 +2,7 @@ | |||
using Tensorflow.Train; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
namespace Tensorflow.Checkpoint; | |||
@@ -18,13 +19,13 @@ public class TrackableView | |||
_root_ref = obj; | |||
} | |||
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
obj._maybe_initialize_trackable(); | |||
Dictionary<string, Trackable> children = new(); | |||
// Note: in python the return type of `Trackable._trackable_children` is not fixed. | |||
// Therefore it uses `convert_to_trackable` to have an extra process. | |||
foreach (var pair in obj._trackable_children(save_type)) | |||
foreach (var pair in obj._trackable_children(save_type, cache)) | |||
{ | |||
children[pair.Key] = pair.Value; | |||
} | |||
@@ -33,7 +33,7 @@ public class TrackableSaver | |||
} | |||
private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
{ | |||
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
@@ -125,7 +125,7 @@ public class TrackableSaver | |||
} | |||
Dictionary<Tensor, string> feed_dict = new(); | |||
bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); | |||
bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); | |||
if (checkpoint_number is not null) | |||
{ | |||
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | |||
@@ -133,6 +133,7 @@ public class TrackableSaver | |||
Tensor file_prefix_tensor; | |||
Tensor object_graph_tensor; | |||
string file_prefix_to_save; | |||
if (use_session) | |||
{ | |||
if (_object_graph_feed_tensor is null) | |||
@@ -145,16 +146,18 @@ public class TrackableSaver | |||
object_graph_tensor = _object_graph_feed_tensor; | |||
file_prefix_tensor = _file_prefix_feed_tensor; | |||
feed_dict[file_prefix_tensor] = file_prefix; | |||
file_prefix_to_save = ""; | |||
} | |||
else | |||
{ | |||
// In python there is `with ops.device("/cpu:0")`. | |||
file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); | |||
object_graph_tensor = null; | |||
file_prefix_to_save = file_prefix; | |||
} | |||
var (save_path, new_feed_additions) = | |||
save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); | |||
save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options); | |||
if (new_feed_additions is not null) | |||
{ | |||
@@ -6,9 +6,254 @@ using Tensorflow.Train; | |||
using static Tensorflow.ApiDef.Types; | |||
using static Tensorflow.CostGraphDef.Types; | |||
using static Tensorflow.OptimizerOptions.Types; | |||
using static Tensorflow.Binding; | |||
using System.Text.RegularExpressions; | |||
using System.Linq; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Training; | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow.Checkpoint | |||
{ | |||
/// <summary> | |||
/// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. | |||
/// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. | |||
/// </summary> | |||
public interface IFunctionHolder | |||
{ | |||
int ArgCount { get; } | |||
object DynamicInvoke(params object[] args); | |||
} | |||
internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder | |||
{ | |||
public int ArgCount => 0; | |||
public object DynamicInvoke(params object[] args) | |||
{ | |||
return Func.DynamicInvoke(args); | |||
} | |||
} | |||
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder | |||
{ | |||
public int ArgCount => 1; | |||
public object DynamicInvoke(params object[] args) | |||
{ | |||
return Func.DynamicInvoke(args); | |||
} | |||
} | |||
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder | |||
{ | |||
public int ArgCount => 2; | |||
public object DynamicInvoke(params object[] args) | |||
{ | |||
return Func.DynamicInvoke(args); | |||
} | |||
} | |||
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder | |||
{ | |||
public int ArgCount => 3; | |||
public object DynamicInvoke(params object[] args) | |||
{ | |||
return Func.DynamicInvoke(args); | |||
} | |||
} | |||
public class Maybe<TA, TB> | |||
{ | |||
private TA? _valueA = default(TA); | |||
private TB? _valueB = default(TB); | |||
private Type _type; | |||
private bool _assigned = false; | |||
public Maybe(TA value) | |||
{ | |||
_valueA = value; | |||
_type= typeof(TA); | |||
_assigned = true; | |||
} | |||
public Maybe(TB value) | |||
{ | |||
_valueB = value; | |||
_type = typeof(TB); | |||
_assigned = true; | |||
} | |||
public Type DataType => _type; | |||
public TA GetValueA() | |||
{ | |||
if(!_assigned || DataType != typeof(TA)) | |||
{ | |||
throw new TypeError("Cannot get the data because of wrong specified type."); | |||
} | |||
return _valueA; | |||
} | |||
public TB GetValueB() | |||
{ | |||
if (!_assigned || DataType != typeof(TB)) | |||
{ | |||
throw new TypeError("Cannot get the data because of wrong specified type."); | |||
} | |||
return _valueB; | |||
} | |||
public object GetValue() | |||
{ | |||
if (!_assigned) | |||
{ | |||
throw new TypeError("Cannot get the data because of wrong specified type."); | |||
} | |||
if(DataType == typeof(TA) && _valueA is not null) | |||
{ | |||
return _valueA; | |||
} | |||
else if(DataType == typeof(TB) && _valueB is not null) | |||
{ | |||
return _valueB; | |||
} | |||
else if(DataType == typeof(TA)) | |||
{ | |||
return _valueA; | |||
} | |||
else | |||
{ | |||
return _valueB; | |||
} | |||
} | |||
public static implicit operator Maybe<TA, TB>(TA a) | |||
{ | |||
return new Maybe<TA, TB>(a); | |||
} | |||
public static implicit operator Maybe<TA, TB>(TB b) | |||
{ | |||
return new Maybe<TA, TB>(b); | |||
} | |||
} | |||
internal class SingleDeviceSaver | |||
{ | |||
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||
{ | |||
_tensor_slice_dict = tensor_slice_dict; | |||
} | |||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict) | |||
{ | |||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
x => x.Key, x => x.Value.ToDictionary( | |||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
} | |||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | |||
{ | |||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
x => x.Key, x => x.Value.ToDictionary( | |||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
} | |||
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||
{ | |||
if(options is null) | |||
{ | |||
options = new CheckpointOptions(); | |||
} | |||
List<string> tensor_names = new(); | |||
List<Tensor> tensors = new(); | |||
List<string> slice_specs = new(); | |||
foreach(var pair in _tensor_slice_dict) | |||
{ | |||
var checkpoint_key = pair.Key; | |||
var tensor_slices = pair.Value; | |||
foreach(var slice in tensor_slices) | |||
{ | |||
var slice_spec = slice.Key; | |||
var maybe_tensor = slice.Value; | |||
// TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
if(maybe_tensor.DataType == typeof(SaveSpec)) | |||
{ | |||
var spec = maybe_tensor.GetValueB(); | |||
var tensor_value = spec.tensor; | |||
if (tensor_value is not null) | |||
{ | |||
tensor_names.Add(spec.name); | |||
tensors.Add(tensor_value); | |||
slice_specs.Add(spec.slice_spec); | |||
} | |||
} | |||
else | |||
{ | |||
var tensor = maybe_tensor.GetValueA(); | |||
tensor_names.Add(checkpoint_key); | |||
tensors.Add(tensor); | |||
slice_specs.Add(slice_spec); | |||
} | |||
} | |||
} | |||
// TODO: specify the device. | |||
return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); | |||
} | |||
public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); | |||
public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null) | |||
{ | |||
if(options is null) | |||
{ | |||
options = new CheckpointOptions(); | |||
} | |||
List<string> tensor_names = new(); | |||
List<TF_DataType> tensor_dtypes = new(); | |||
List<string> slice_specs = new(); | |||
foreach(var pair in _tensor_slice_dict) | |||
{ | |||
var checkpoint_key = pair.Key; | |||
var tensor_slices = pair.Value; | |||
foreach(var slice in tensor_slices) | |||
{ | |||
var slice_spec = slice.Key; | |||
var maybe_tensor = slice.Value; | |||
// TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
if(maybe_tensor.DataType == typeof(SaveSpec)) | |||
{ | |||
var spec = maybe_tensor.GetValueB(); | |||
tensor_dtypes.Add(spec.dtype); | |||
slice_specs.Add(spec.slice_spec); | |||
tensor_names.Add(spec.name); | |||
} | |||
else | |||
{ | |||
var tensor = maybe_tensor.GetValueA(); | |||
tensor_dtypes.Add(tensor.dtype); | |||
slice_specs.Add(slice_spec); | |||
tensor_names.Add(checkpoint_key); | |||
} | |||
} | |||
} | |||
string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; | |||
// tf python has code `with ops.device(restore_device):` here. | |||
tf.device(restore_device); // may be risky. | |||
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
int idx = 0; | |||
foreach(var pair in _tensor_slice_dict) | |||
{ | |||
var checkpoint_key = pair.Key; | |||
var tensor_slices = pair.Value; | |||
foreach(var slice_spec in tensor_slices.Keys) | |||
{ | |||
var restored_tensor = restored_tensors[idx++]; | |||
if (!restored_tensor_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||
} | |||
restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; | |||
} | |||
} | |||
return restored_tensor_dict; | |||
} | |||
public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix)); | |||
} | |||
/// <summary> | |||
/// Saves checkpoints directly from multiple devices. | |||
/// Note that this is a low-level utility which stores Tensors in the keys | |||
@@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint | |||
/// </summary> | |||
public class MultiDeviceSaver | |||
{ | |||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, object>> serialized_tensors, | |||
private Dictionary<string, SingleDeviceSaver> _single_device_savers; | |||
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers; | |||
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; | |||
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys; | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||
/// <param name="registered_savers"></param> | |||
/// <param name="call_with_mapped_capture"></param> | |||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
{ | |||
_keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); | |||
_restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>(); | |||
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||
foreach(var pair in serialized_tensors) | |||
{ | |||
var obj = pair.Key; | |||
var tensor_dict = pair.Value; | |||
IFunctionHolder restore_fn; | |||
if(obj is null) | |||
{ | |||
restore_fn = new FunctionHolder<object?>(() => null); | |||
} | |||
else | |||
{ | |||
restore_fn = null; | |||
// TODO: implement obj._restore_from_tensors | |||
} | |||
foreach(var item in tensor_dict) | |||
{ | |||
var checkpoint_key = item.Key; | |||
IDictionary<string, Tensor> spec_to_tensor; | |||
if(item.Value.DataType != typeof(IDictionary<string, Tensor>)) | |||
{ | |||
spec_to_tensor = new Dictionary<string, Tensor>(); | |||
spec_to_tensor[""] = item.Value.GetValueA(); | |||
} | |||
else | |||
{ | |||
spec_to_tensor = item.Value.GetValueB(); | |||
} | |||
foreach(var spec in spec_to_tensor) | |||
{ | |||
var slice_spec = spec.Key; | |||
var tensor = spec.Value; | |||
if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) | |||
{ | |||
throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + | |||
$"slice spec. This is invalid because one will overwrite the " + | |||
$"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); | |||
} | |||
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | |||
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||
// skip the process of device name because lack of API. | |||
var host_device = tensor.Device; | |||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||
} | |||
internal_dict[checkpoint_key][slice_spec] = tensor; | |||
} | |||
} | |||
} | |||
_single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); | |||
_registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>(); | |||
if(registered_savers is not null && registered_savers.Count > 0) | |||
{ | |||
// TODO: complete the implementation. | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
public Operation? save(string file_prefix, CheckpointOptions? options= null) | |||
public Operation save(string file_prefix, CheckpointOptions? options= null) | |||
{ | |||
throw new NotImplementedException(); | |||
if(options is null) | |||
{ | |||
options = new CheckpointOptions(); | |||
} | |||
tf.device("CPU"); // may be risky. | |||
// TODO: optimize the implementation with new APIs adding to `string_ops`. | |||
string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; | |||
var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); | |||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
Operation save_fn() | |||
{ | |||
List<Tensor> saved_prefixes= new(); | |||
foreach(var saver in _registered_savers) | |||
{ | |||
// TODO: implementi it later. | |||
throw new NotImplementedException(); | |||
} | |||
int num_shards = _single_device_savers.Count; | |||
List<Operation> sharded_saves = new(); | |||
var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); | |||
string? last_device = null; | |||
int shard = 0; | |||
foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) | |||
{ | |||
var device = pair.Key; | |||
var saver = pair.Value; | |||
last_device = device; | |||
// skip the extra process of device name because of lack of API. | |||
tf.device(device); | |||
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
saved_prefixes.Add(shard_prefix); | |||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||
} | |||
using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||
{ | |||
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||
tf.device(merge_device); | |||
return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); | |||
} | |||
} | |||
if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) | |||
{ | |||
// TODO: implement it. Currently `autograph` does not support the function with non parameter. | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
return save_fn(); | |||
} | |||
} | |||
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||
public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); | |||
public IDictionary<string, Operation> restore(string file_prefix, CheckpointOptions? options = null) | |||
{ | |||
if(options is null) | |||
{ | |||
options = new CheckpointOptions(); | |||
} | |||
IDictionary<string, Operation> restore_func() | |||
{ | |||
Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||
Dictionary<string, Operation> restore_ops = new(); | |||
foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) | |||
{ | |||
var device = single_saver.Key; | |||
var saver = single_saver.Value; | |||
tf.device(device); | |||
var restored_tensor_dict = saver.restore(file_prefix, options); | |||
foreach(var pair in restored_tensor_dict) | |||
{ | |||
var checkpoint_key = pair.Key; | |||
var slice_and_tensor = pair.Value; | |||
foreach(var item in slice_and_tensor) | |||
{ | |||
var slice_spec = item.Key; | |||
var tensor = item.Value; | |||
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
if (!string.IsNullOrEmpty(slice_spec)) | |||
{ | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
Dictionary<string, Tensor> dict = new(); | |||
dict[slice_spec] = tensor; | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; | |||
} | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
} | |||
restore_fn_input_count[restore_fn]--; | |||
if (restore_fn_input_count[restore_fn] == 0) | |||
{ | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
foreach(var input in restore_fn_inputs[restore_fn]) | |||
{ | |||
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
} | |||
var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
if(ret is IDictionary<string, Operation>) | |||
{ | |||
var dict = (IDictionary<string, Operation>)ret; | |||
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
foreach(var item in _registered_savers) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return restore_ops; | |||
} | |||
// TODO: complete the implementation. Currently skip it because of lack of API. | |||
bool has_custom_device_saver = false; | |||
if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) | |||
{ | |||
// TODO: implement it. Currently `autograph` does not support the function with non parameter. | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
return restore_func(); | |||
} | |||
} | |||
/// <summary> | |||
/// Serializes to a SaverDef referencing the current graph. | |||
/// </summary> | |||
public SaverDef to_proto() | |||
{ | |||
var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); | |||
var save_tensor = _traced_save(filename_tensor); | |||
var restore_op = _traced_restore(filename_tensor).op; | |||
return new SaverDef() | |||
{ | |||
FilenameTensorName = filename_tensor.name, | |||
SaveTensorName = save_tensor.name, | |||
RestoreOpName = restore_op.name, | |||
Version = SaverDef.Types.CheckpointFormatVersion.V2 | |||
}; | |||
} | |||
[AutoGraph] | |||
private Tensor _traced_save(Tensor file_prefix) | |||
{ | |||
var save_op = save(file_prefix.StringData()[0]); | |||
tf.device("cpu:0"); | |||
using (ops.control_dependencies(new object[]{ save_op })) | |||
{ | |||
return array_ops.identity(file_prefix); | |||
} | |||
} | |||
[AutoGraph] | |||
private Tensor _traced_restore(Tensor file_prefix) | |||
{ | |||
var restore_op = restore(file_prefix.StringData()[0]); | |||
tf.device("cpu:0"); | |||
using (ops.control_dependencies(new object[] { restore_op })) | |||
{ | |||
return array_ops.identity(file_prefix); | |||
} | |||
} | |||
private static Tensor registered_saver_filename(string filename, string saver_name) | |||
{ | |||
return tf.constant($"{filename}-{saver_name}"); | |||
} | |||
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) | |||
{ | |||
throw new NotImplementedException(); | |||
return filename_tensor; | |||
} | |||
} | |||
} |
@@ -0,0 +1,35 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Train; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
public interface ISerializedAttributes | |||
{ | |||
IDictionary<string, Trackable> Functions { get; } | |||
IDictionary<string, Trackable> CheckpointableObjects { get; } | |||
/// <summary> | |||
/// Returns functions to attach to the root object during serialization. | |||
/// </summary> | |||
IDictionary<string, Trackable> FunctionsToSerialize { get; } | |||
/// <summary> | |||
/// Returns objects to attach to the root object during serialization. | |||
/// </summary> | |||
IDictionary<string, Trackable> ObjectsToSerialize{get; } | |||
/// <summary> | |||
/// Saves function dictionary, and validates dictionary values. | |||
/// </summary> | |||
/// <param name="function_dict"></param> | |||
IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict); | |||
/// <summary> | |||
/// Saves objects to a dictionary, and validates the values. | |||
/// </summary> | |||
/// <param name="object_dict"></param> | |||
IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict); | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Operations.Activation; | |||
using static Tensorflow.Binding; | |||
@@ -24,7 +25,7 @@ namespace Tensorflow.Train | |||
} | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
if(save_type != SaveType.SAVEDMODEL) | |||
{ | |||
@@ -4,57 +4,130 @@ using Tensorflow.Train; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
namespace Tensorflow; | |||
public class AugmentedGraphView: ObjectGraphView | |||
{ | |||
// private object _children_cache; | |||
// private object _serialization_cache; | |||
private Dictionary<Trackable, IDictionary<string, Trackable>> _children_cache; | |||
private Dictionary<string, IDictionary<Trackable, ISerializedAttributes>> _serialization_cache; | |||
private List<string> _untraces_functions; | |||
private Dictionary<ConcreteFunction, ConcreteFunction> _wrapped_functions; | |||
public AugmentedGraphView(Trackable root): base(root) | |||
{ | |||
_untraces_functions = new(); | |||
_children_cache= new Dictionary<Trackable, IDictionary<string, Trackable>>(); | |||
_serialization_cache = new Dictionary<string, IDictionary<Trackable, ISerializedAttributes>>(); | |||
_untraces_functions = new List<string>(); | |||
_wrapped_functions = new Dictionary<ConcreteFunction, ConcreteFunction>(); | |||
} | |||
public void set_signature(object signature_map, object wrapped_functions) | |||
public void set_signature(SignatureMap signature_map, IDictionary<ConcreteFunction, ConcreteFunction> wrapped_functions) | |||
{ | |||
// TODO: cache | |||
list_children(Root); | |||
var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME; | |||
if (!_children_cache.ContainsKey(Root)) | |||
{ | |||
_children_cache[Root] = new Dictionary<string, Trackable>(); | |||
} | |||
_children_cache[Root][name] = signature_map; | |||
_wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value); | |||
} | |||
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) | |||
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null) | |||
{ | |||
Dictionary<string, Trackable> children = new(); | |||
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) | |||
if(serialization_cache is not null) | |||
{ | |||
throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`."); | |||
} | |||
if (!_children_cache.ContainsKey(obj)) | |||
{ | |||
Dictionary<string, Trackable> children = new Dictionary<string, Trackable>(); | |||
_children_cache[obj] = children; | |||
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache)) | |||
{ | |||
var name = pair.Name; | |||
var child = pair.Refer; | |||
if(child is ConcreteFunction) | |||
{ | |||
child = maybe_uncache_variable_captures((ConcreteFunction)child); | |||
} | |||
children[name] = child; | |||
} | |||
if (obj is Function && children.Count == 0) | |||
{ | |||
_untraces_functions.Add(((Function)obj).Name); | |||
} | |||
} | |||
List<TrackableReference> res = new(); | |||
foreach(var pair in _children_cache[obj]) | |||
{ | |||
var name = pair.Name; | |||
var child = pair.Refer; | |||
children[name] = child; | |||
res.Add(new TrackableReference(pair.Key, pair.Value)); | |||
} | |||
if (obj is Function && children.Count == 0) | |||
return res; | |||
} | |||
private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function) | |||
{ | |||
if (_wrapped_functions.ContainsKey(concrete_function)) | |||
{ | |||
_untraces_functions.Add(((Function)obj).Name); | |||
return _wrapped_functions[concrete_function]; | |||
} | |||
// skip the process here because of lack of feature. | |||
// In the future, we may add an attribute which could specify if the variable is supposed to be cached. | |||
//foreach(var capture in concrete_function.CapturedInputs) | |||
//{ | |||
return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | |||
//} | |||
return concrete_function; | |||
} | |||
public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
{ | |||
// TODO: implement it if needed. | |||
Trackable get_merged_trackable(Trackable x) | |||
{ | |||
// TODO: complete it with new definitions `Asset` and `TrackableConstant`. | |||
return x; | |||
} | |||
var trackable_objects = base.breadth_first_traversal(); | |||
foreach(var obj in _children_cache.Keys) | |||
{ | |||
// skip the deletion of cache (maybe do it later). | |||
foreach(var pair in _children_cache[obj]) | |||
{ | |||
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value); | |||
} | |||
} | |||
return base.breadth_first_traversal(); | |||
} | |||
public List<(string, Trackable)> list_dependencies(Trackable obj) | |||
{ | |||
// TODO: deal with cache. | |||
return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); | |||
IDictionary<string, Trackable> children; | |||
if (!_children_cache.ContainsKey(obj)) | |||
{ | |||
children= new Dictionary<string, Trackable>(); | |||
} | |||
else | |||
{ | |||
children= _children_cache[obj]; | |||
} | |||
List<(string, Trackable)> res = new(); | |||
foreach(var pair in obj.deserialization_dependencies(children)) | |||
{ | |||
res.Add((pair.Key, pair.Value)); | |||
} | |||
return res; | |||
} | |||
public Trackable get_child(Trackable obj, string name) | |||
{ | |||
throw new NotImplementedException(); | |||
return _children_cache[obj][name]; | |||
} | |||
} |
@@ -141,16 +141,16 @@ public class SaveableView | |||
foreach (var node in _nodes) | |||
{ | |||
var node_id = _node_ids[node]; | |||
List<int> deps = new(); | |||
List<int> deps = new List<int>(); | |||
dependency_map.Add(node_id, deps); | |||
// TODO: deal with captured tensor. | |||
string node_path; | |||
foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) | |||
{ | |||
if (!_node_ids.ContainsKey(dep)) | |||
{ | |||
node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); | |||
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); | |||
throw new ValueError( | |||
$"Found an untracked dependency. Object {node_path} depends on {dep}, " + | |||
$"but this dependency isn't listed as a child. Please track this child by " + | |||
@@ -24,7 +24,7 @@ public static partial class SavedModelUtils | |||
}.Select(x => (int)x); | |||
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj, | |||
string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) | |||
string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) | |||
{ | |||
if (options is null) | |||
{ | |||
@@ -41,9 +41,9 @@ public static partial class SavedModelUtils | |||
if (!experimental_skip_checkpoint) | |||
{ | |||
Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); | |||
SavedModelUtils.get_or_create_variables_dir(export_dir); | |||
CheckpointOptions ckpt_options = new(options.experimental_io_device); | |||
object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||
object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||
} | |||
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); | |||
@@ -67,7 +67,7 @@ public static partial class SavedModelUtils | |||
} | |||
var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); | |||
File.WriteAllText(path, saved_model.ToString()); | |||
File.WriteAllBytes(path, saved_model.ToByteArray()); | |||
if (options.save_debug_info) | |||
{ | |||
@@ -81,7 +81,7 @@ public static partial class SavedModelUtils | |||
private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | |||
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||
IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||
{ | |||
if (ops.inside_function()) | |||
{ | |||
@@ -95,9 +95,9 @@ public static partial class SavedModelUtils | |||
} | |||
AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | |||
if (signatures is not null) | |||
if (signatures is null) | |||
{ | |||
throw new NotImplementedException(); | |||
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); | |||
} | |||
// TODO: process of aignatures and wrapped_functions | |||
@@ -125,7 +125,7 @@ public static partial class SavedModelUtils | |||
} | |||
private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, | |||
IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist, | |||
ConcreteFunction signatures, IEnumerable<string> namespace_whitelist, | |||
bool save_custom_gradients) | |||
{ | |||
var resource_initializers = saveable_view.get_concrete_resource_initializers(); | |||
@@ -1,15 +1,84 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Train; | |||
namespace Tensorflow; | |||
public static class SignatureSerializationUtils | |||
{ | |||
internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature"; | |||
internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures"; | |||
internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5; | |||
public static SignatureMap create_signature_map(IDictionary<string, Trackable> signatures) | |||
{ | |||
var signature_map = new SignatureMap(); | |||
foreach (var pair in signatures) | |||
{ | |||
var name = pair.Key; | |||
var func = pair.Value; | |||
Debug.Assert(func is ConcreteFunction); | |||
// TODO: assert the `func.structured_outputs` and arg_keywords. | |||
signature_map._add_signature(name, (ConcreteFunction)func); | |||
} | |||
return signature_map; | |||
} | |||
public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view) | |||
{ | |||
var children = graph_view.list_children(graph_view.Root); | |||
List<Trackable> possible_signatures = new(); | |||
foreach (var item in children) | |||
{ | |||
var name = item.Name; | |||
var child = item.Refer; | |||
if(child is not (Function or ConcreteFunction)) | |||
{ | |||
continue; | |||
} | |||
if(name == DEFAULT_SIGNATURE_ATTR) | |||
{ | |||
Debug.Assert(child is ConcreteFunction); | |||
return (ConcreteFunction)child; | |||
} | |||
ConcreteFunction concrete = get_signature(child); | |||
if(concrete is not null && valid_signature(concrete)) | |||
{ | |||
possible_signatures.Add(concrete); | |||
} | |||
} | |||
if(possible_signatures.Count == 1) | |||
{ | |||
var signature = get_signature(possible_signatures[0]); | |||
if(signature is not null && valid_signature(signature)) | |||
{ | |||
return signature; | |||
} | |||
} | |||
return null; | |||
} | |||
private static ConcreteFunction get_signature(Trackable function) | |||
{ | |||
// TODO: implement it. | |||
return null; | |||
} | |||
private static bool valid_signature(ConcreteFunction concreate_function) | |||
{ | |||
// TODO: implement it. | |||
return false; | |||
} | |||
} | |||
public class SignatureMap: Trackable | |||
{ | |||
private Dictionary<string, Function> _signatures; | |||
private Dictionary<string, ConcreteFunction> _concrete_signatures; | |||
private Dictionary<string, Trackable> _signatures; | |||
public SignatureMap() | |||
{ | |||
@@ -18,7 +87,7 @@ public class SignatureMap: Trackable | |||
public void _add_signature(string name, ConcreteFunction concrete_function) | |||
{ | |||
_concrete_signatures[name] = concrete_function; | |||
_signatures[name] = concrete_function; | |||
} | |||
public void _add_signature(string name, Function concrete_function) | |||
@@ -26,33 +95,13 @@ public class SignatureMap: Trackable | |||
_signatures[name] = concrete_function; | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
if (save_type != SaveType.SAVEDMODEL) | |||
{ | |||
return new Dictionary<string, Trackable>(); | |||
} | |||
Dictionary<string, Trackable> res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); | |||
foreach (var pair in _concrete_signatures) | |||
{ | |||
res[pair.Key] = pair.Value; | |||
} | |||
return res; | |||
} | |||
public static SignatureMap create_signature_map(IDictionary<string, ConcreteFunction> signatures) | |||
{ | |||
var signature_map = new SignatureMap(); | |||
foreach (var pair in signatures) | |||
{ | |||
var name = pair.Key; | |||
var func = pair.Value; | |||
// TODO: assert the arg_keywords | |||
signature_map._add_signature(name, func); | |||
} | |||
return signature_map; | |||
return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); | |||
} | |||
} |
@@ -16,18 +16,38 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public static class saveable_object_util | |||
/// <summary> | |||
/// A SaveableObject that defines `Trackable` checkpointing steps. | |||
/// </summary> | |||
public class TrackableSaveable : MySaveableObject | |||
{ | |||
public class TrackableSaveable: MySaveableObject | |||
private string _prefix; | |||
private IEnumerable<string> _local_names; | |||
private Trackable _trackable; | |||
private bool _call_with_mapped_captures; | |||
// TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. | |||
public TrackableSaveable(Trackable obj, IEnumerable<SaveSpec> specs, string name, IEnumerable<string> local_names, | |||
string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) | |||
{ | |||
_prefix = prefix; | |||
_trackable = obj; | |||
_local_names = local_names; | |||
_call_with_mapped_captures = call_with_mapped_captures; | |||
} | |||
// TODO: complete this class. | |||
} | |||
public static class saveable_object_util | |||
{ | |||
/// <summary> | |||
/// Returns the variables and names that will be used for a Saver. | |||
/// </summary> | |||
@@ -57,7 +77,7 @@ namespace Tensorflow | |||
} | |||
/// <summary> | |||
/// Create `SaveableObject`s from an operation. | |||
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | |||
/// </summary> | |||
/// <param name="op"></param> | |||
/// <param name="name"></param> | |||
@@ -79,6 +99,74 @@ namespace Tensorflow | |||
} | |||
} | |||
/// <summary> | |||
/// Create `SaveableObject`s from an operation. | |||
/// </summary> | |||
/// <param name="op"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Trackable obj, string name) | |||
{ | |||
// The `op` maybe `Variable` or `Trackable`. | |||
if (obj is BaseResourceVariable) | |||
{ | |||
var variable = obj as BaseResourceVariable; | |||
if (variable.InGraphMode) | |||
{ | |||
yield return new ResourceVariableSaveable(variable.GraphElement, "", name); | |||
} | |||
else | |||
{ | |||
Debug.Assert(variable is ResourceVariable); | |||
yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); | |||
} | |||
} | |||
else | |||
{ | |||
foreach(var pair in saveable_objects_from_trackable(obj)) | |||
{ | |||
var attr = pair.Key; | |||
var factory = pair.Value; | |||
string full_name; | |||
if(attr == Trackable.Constants.VARIABLE_VALUE_KEY) | |||
{ | |||
full_name = name; | |||
} | |||
else | |||
{ | |||
full_name = name + "_" + attr; | |||
} | |||
if(factory.DataType == typeof(ResourceVariable)) | |||
{ | |||
var variable = factory.GetValueA(); | |||
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
{ | |||
yield return op; | |||
} | |||
} | |||
else | |||
{ | |||
var variable = factory.GetValueB(); | |||
foreach (var op in saveable_objects_for_op(variable, variable.name)) | |||
{ | |||
yield return op; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// Create `SaveableObject`s from an operation. | |||
/// </summary> | |||
/// <param name="op"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static IEnumerable<MySaveableObject> saveable_objects_for_op(MySaveableObject obj, string name) | |||
{ | |||
yield return obj; | |||
} | |||
public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) | |||
{ | |||
op_list = op_list.OrderBy(x => x.Name).ToArray(); | |||
@@ -127,16 +215,55 @@ namespace Tensorflow | |||
return names_to_saveables; | |||
} | |||
public static IDictionary<string, ResourceVariable> saveable_objects_from_trackable(Trackable obj) | |||
public static IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||
{ | |||
// TODO: complete the implementation. | |||
return obj.gather_saveables_for_checkpoint(); | |||
// skip the process of type `PythonState` | |||
if (trackable_has_serialize_to_tensor(obj)) | |||
{ | |||
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; | |||
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | |||
var tensor_dict = obj.serialize_to_tensors(); | |||
List<SaveSpec> specs = new(); | |||
List<string> local_names = new(); | |||
string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | |||
foreach(var pair in tensor_dict) | |||
{ | |||
var tensor_name = pair.Key; | |||
var maybe_tensor = pair.Value; | |||
local_names.Add(tensor_name); | |||
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
IDictionary<string, Tensor> internal_dict; | |||
if(maybe_tensor.DataType == typeof(Tensor)) | |||
{ | |||
internal_dict= new Dictionary<string, Tensor>(); | |||
internal_dict[""] = maybe_tensor.GetValueA(); | |||
} | |||
else | |||
{ | |||
internal_dict = maybe_tensor.GetValueB(); | |||
} | |||
foreach(var item in internal_dict) | |||
{ | |||
specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
} | |||
} | |||
Dictionary<string, Maybe<ResourceVariable, MySaveableObject>> res = new(); | |||
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
return res; | |||
} | |||
else | |||
{ | |||
return obj.gather_saveables_for_checkpoint(); | |||
} | |||
} | |||
public static bool trackable_has_serialize_to_tensor(Trackable obj) | |||
{ | |||
// TODO: implement it. | |||
return false; | |||
return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); | |||
} | |||
internal static string convert_to_string(string x) | |||
@@ -158,27 +285,28 @@ namespace Tensorflow | |||
public Trackable Obj => _obj; | |||
public IList<MySaveableObject> mySaveables=> _saveables; | |||
public override IDictionary<string, object> serialize_to_tensors() | |||
public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
{ | |||
return saveable_objects_to_tensor_dict(_saveables); | |||
return saveable_object_to_tensor_dict(_saveables); | |||
} | |||
/// <summary> | |||
/// Converts a list of SaveableObjects to a tensor dictionary. | |||
/// </summary> | |||
/// <param name="saveables"></param> | |||
public static Dictionary<string, object> saveable_objects_to_tensor_dict(IList<MySaveableObject> saveables) | |||
public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
{ | |||
Dictionary<string, object> tensor_dict = new(); | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
foreach (var saveable in saveables) | |||
{ | |||
foreach(var spec in saveable.specs) | |||
{ | |||
// skip the check that if `spec` is callable. | |||
var name = saveable_object_util.convert_to_string(spec.name); | |||
var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); | |||
if (!string.IsNullOrEmpty(slice_spec)) | |||
{ | |||
throw new NotImplementedException(); | |||
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor; | |||
} | |||
else | |||
{ | |||
@@ -16,7 +16,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.ModelSaving; | |||
using Tensorflow.Training; | |||
using static Tensorflow.Binding; | |||
@@ -39,8 +42,8 @@ namespace Tensorflow.Train | |||
protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||
protected IDictionary<string, ResourceVariable> _self_saveable_object_factories = | |||
new Dictionary<string, ResourceVariable>(); | |||
protected IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||
new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>(); | |||
private bool _manual_tracking = true; | |||
private static Trackable _none = new Function(); | |||
@@ -94,9 +97,13 @@ namespace Tensorflow.Train | |||
// assign again. It will add this variable to our dependencies, and if there | |||
// is a non-trivial restoration queued, it will handle that. This also | |||
// handles slot variables. | |||
if (!args.Overwrite || new_variable is RefVariable) | |||
return _track_checkpointable(new_variable, name: args.Name, | |||
overwrite: args.Overwrite); | |||
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) | |||
{ | |||
var temp = new_variable as Trackable; | |||
var res = _track_trackable(temp, args.Name, args.Overwrite); | |||
Debug.Assert(res is IVariableV1); | |||
return res as IVariableV1; | |||
} | |||
else | |||
return new_variable; | |||
} | |||
@@ -122,13 +129,16 @@ namespace Tensorflow.Train | |||
/// </summary> | |||
public void _maybe_initialize_trackable() | |||
{ | |||
if(_unconditional_checkpoint_dependencies is not null) | |||
{ | |||
return; | |||
} | |||
_self_update_uid = -1; | |||
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||
_unconditional_dependency_names = new Dictionary<string, Trackable>(); | |||
} | |||
// TODO: cache | |||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||
{ | |||
_maybe_initialize_trackable(); | |||
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | |||
@@ -139,8 +149,8 @@ namespace Tensorflow.Train | |||
_maybe_initialize_trackable(); | |||
if (!_manual_tracking) return trackable; | |||
var new_reference = new TrackableReference(name, trackable); | |||
var current_object = _lookupup_dependency(name); | |||
var current_object = _lookup_dependency(name); | |||
if(current_object is null) | |||
{ | |||
_unconditional_checkpoint_dependencies.Add(new_reference); | |||
@@ -170,7 +180,7 @@ namespace Tensorflow.Train | |||
// TODO: complete the implementation. | |||
} | |||
public virtual Trackable? _lookupup_dependency(string name) | |||
public virtual Trackable? _lookup_dependency(string name) | |||
{ | |||
if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; | |||
else return null; | |||
@@ -199,8 +209,8 @@ namespace Tensorflow.Train | |||
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); | |||
} | |||
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null, | |||
IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null) | |||
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable> object_map, | |||
IDictionary<Tensor, Tensor> tensor_map, SaveOptions? options = null) | |||
{ | |||
var (self_object_map, self_tensor_map) = map_resources(options); | |||
foreach (var pair in self_object_map) | |||
@@ -215,9 +225,17 @@ namespace Tensorflow.Train | |||
return self_tensor_map.Keys.ToList(); | |||
} | |||
public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint() | |||
public virtual IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
{ | |||
return _self_saveable_object_factories; | |||
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||
{ | |||
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
return _self_saveable_object_factories; | |||
} | |||
} | |||
/// <summary> | |||
@@ -229,7 +247,7 @@ namespace Tensorflow.Train | |||
/// </summary> | |||
/// <returns></returns> | |||
/// <exception cref="NotImplementedException"></exception> | |||
public virtual IDictionary<string, object> serialize_to_tensors() | |||
public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -1,4 +1,5 @@ | |||
using System.Collections.Generic; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Exceptions; | |||
using Tensorflow.Train; | |||
@@ -22,7 +23,7 @@ public static class TrackableUtils | |||
private static string _ESCAPE_CHAR = "."; | |||
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | |||
{ | |||
return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); | |||
@@ -145,4 +146,27 @@ public static class TrackableUtils | |||
return $"root.{string.Join(".", paths.Select(x => x.Name))}"; | |||
} | |||
} | |||
/// <summary> | |||
/// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key. | |||
/// </summary> | |||
/// <param name="key"></param> | |||
/// <param name="prefix"></param> | |||
/// <returns></returns> | |||
public static string extract_local_name(string key, string? prefix = null) | |||
{ | |||
if(prefix is null) | |||
{ | |||
prefix = ""; | |||
} | |||
var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix; | |||
try | |||
{ | |||
return key.Substring(key.IndexOf(search_key) + search_key.Length); | |||
} | |||
catch(ArgumentOutOfRangeException) | |||
{ | |||
return key; | |||
} | |||
} | |||
} |
@@ -9,6 +9,7 @@ using System.Runtime.InteropServices; | |||
using System.Text; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Keras; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Operations.Activation; | |||
using Tensorflow.Train; | |||
using static Tensorflow.ApiDef.Types; | |||
@@ -243,7 +244,7 @@ namespace Tensorflow.Training | |||
_last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
check_external_modification(); | |||
if (_non_append_mutation_value) | |||
@@ -4,6 +4,8 @@ using Tensorflow.Eager; | |||
using Tensorflow.Variables; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
using System.Collections.Generic; | |||
using Tensorflow.ModelSaving; | |||
namespace Tensorflow | |||
{ | |||
@@ -20,6 +22,7 @@ namespace Tensorflow | |||
public string UniqueId => _unique_id; | |||
protected bool _in_graph_mode; | |||
internal bool InGraphMode => _in_graph_mode; | |||
protected bool _trainable; | |||
public bool Trainable => _trainable; | |||
@@ -17,7 +17,9 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Checkpoint; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -235,5 +237,12 @@ namespace Tensorflow | |||
{ | |||
return _graph_element.eval(session); | |||
} | |||
public override IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
{ | |||
var res = new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>(); | |||
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||
return res; | |||
} | |||
} | |||
} |
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
@@ -351,7 +352,7 @@ namespace Tensorflow.Keras.Engine | |||
return output_tensors; | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | |||
.ToDictionary(x => x.Key, x => x.Value); | |||
@@ -1,4 +1,5 @@ | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Train; | |||
@@ -9,16 +10,16 @@ public abstract partial class Layer | |||
{ | |||
public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | |||
public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
IDictionary<string, Trackable> children; | |||
if (save_type == SaveType.SAVEDMODEL) | |||
{ | |||
// TODO: deal with cache. | |||
Debug.Assert(cache is not null); | |||
children = TrackableSavedModelSaver.trackable_children(cache); | |||
} | |||
else | |||
@@ -88,9 +88,29 @@ namespace Tensorflow.Keras.Engine | |||
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
public CallContext CallContext => callContext.Value; | |||
public Tensor[] input => inboundNodes[0].input_tensors; | |||
public Tensor[] input | |||
{ | |||
get | |||
{ | |||
if(inboundNodes is not null && inboundNodes.Count > 0) | |||
{ | |||
return inboundNodes[0].input_tensors; | |||
} | |||
return null; | |||
} | |||
} | |||
public Dictionary<int, List<INode>> NodesByDepth { get; set; } | |||
public Shape OutputShape => inboundNodes[0].Outputs.shape; | |||
public Shape OutputShape | |||
{ | |||
get | |||
{ | |||
if(inboundNodes is not null && inboundNodes.Count > 0) | |||
{ | |||
return inboundNodes[0].Outputs.shape; | |||
} | |||
return null; | |||
} | |||
} | |||
protected List<ILayer> _self_tracked_trackables; | |||
public Layer(LayerArgs args) | |||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
bool include_optimizer = true, | |||
string save_format = "tf", | |||
SaveOptions? options = null, | |||
IDictionary<string, ConcreteFunction>? signatures = null, | |||
ConcreteFunction? signatures = null, | |||
bool save_traces = true) | |||
{ | |||
if (save_format != "pb") | |||
@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
@@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
if(save_type == SaveType.SAVEDMODEL) | |||
{ | |||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||
public partial class KerasSavedModelUtils | |||
{ | |||
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures, | |||
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||
SaveOptions? options, bool save_traces = true) | |||
{ | |||
if (!overwrite && File.Exists(filepath)) | |||
@@ -54,12 +54,7 @@ public partial class KerasSavedModelUtils | |||
} | |||
var metadata = generate_keras_metadata(saved_nodes, node_paths); | |||
using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, | |||
FileAccess.Write)) | |||
{ | |||
var writer = new StreamWriter(f); | |||
writer.Write(metadata.ToString()); | |||
} | |||
File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); | |||
if (!include_optimizer) | |||
{ | |||
@@ -19,7 +19,7 @@ public partial class KerasSavedModelUtils | |||
/// <param name="layer"></param> | |||
/// <param name="serialization_cache"></param> | |||
/// <returns></returns> | |||
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, object> serialization_cache) | |||
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||
@@ -55,7 +55,7 @@ public partial class KerasSavedModelUtils | |||
/// <param name="layer"></param> | |||
/// <param name="serialization_cache"></param> | |||
/// <returns></returns> | |||
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, object> serialization_cache) | |||
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
// TODO: deal with type `RevivedLayer` and `Sequential`. | |||
@@ -18,12 +18,12 @@ public abstract class SavedModelSaver | |||
public abstract string TrackingMetadata { get; } | |||
public abstract IDictionary<string, Trackable> objects_to_serialize( | |||
IDictionary<string, object> serialization_cache); | |||
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache); | |||
public abstract IDictionary<string, Trackable> functions_to_serialize( | |||
IDictionary<string, object> serialization_cache); | |||
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache); | |||
public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | |||
public IDictionary<string, Trackable> trackable_children(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
if (!KerasSavedModelUtils.ShouldHaveTraces) | |||
{ | |||
@@ -31,7 +31,6 @@ public abstract class SavedModelSaver | |||
} | |||
var children = objects_to_serialize(serialization_cache); | |||
return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||
.ToDictionary(x => x.Key, x => x.Value); | |||
} | |||
@@ -19,12 +19,12 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
get => Constants.LAYER_IDENTIFIER; | |||
} | |||
public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||
public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
return get_serialized_attributes(serialization_cache).ObjectsToSerialize; | |||
} | |||
public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||
public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
return get_serialized_attributes(serialization_cache).FunctionsToSerialize; | |||
} | |||
@@ -33,11 +33,21 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
/// Generates or retrieves serialized attributes from cache. | |||
/// </summary> | |||
/// <param name="serialization_cache"></param> | |||
protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||
protected ISerializedAttributes get_serialized_attributes(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
// TODO: deal with cache. | |||
IDictionary<Trackable, ISerializedAttributes> keras_cache; | |||
if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY)) | |||
{ | |||
keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY]; | |||
} | |||
else | |||
{ | |||
serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary<Trackable, ISerializedAttributes>(); | |||
} | |||
if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj]; | |||
var serialized_attr = SerializedAttributes.Create(_obj); | |||
var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj); | |||
// TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. | |||
if (KerasSavedModelUtils.should_skip_serialization(_obj)) | |||
@@ -56,7 +66,7 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
/// Returns dictionary of serialized attributes. | |||
/// </summary> | |||
/// <param name="serialization_cache"></param> | |||
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache) | |||
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
{ | |||
var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); | |||
var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); | |||
@@ -75,7 +85,7 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
metadata["trainable"] = _obj.Trainable; | |||
// metadata["expects_training_arg"] = _obj._expects_training_arg; | |||
// metadata["dtype"] = policy.serialize(_obj._dtype_policy) | |||
metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); | |||
metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); | |||
// metadata["stateful"] = _obj.stateful; | |||
// metadata["must_restore_from_config"] = _obj.must_restore_from_config; | |||
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | |||
@@ -92,8 +102,10 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
} | |||
} | |||
public static LayerConfig get_serialized(Layer obj) | |||
public static IDictionary<string, object> get_serialized(Layer obj) | |||
{ | |||
return generic_utils.serialize_keras_object(obj); | |||
// TODO: complete the implmentation (need to revise `get_config`). | |||
return new Dictionary<string, object>(); | |||
//return generic_utils.serialize_keras_object(obj); | |||
} | |||
} |
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
/// <summary> | |||
/// Class that tracks and validates all serialization attributes. | |||
/// </summary> | |||
public abstract class SerializedAttributes | |||
public abstract class SerializedAttributes: ISerializedAttributes | |||
{ | |||
protected IDictionary<string, Trackable?> _object_dict; | |||
protected IDictionary<string, Trackable?> _function_dict; | |||
@@ -50,11 +50,11 @@ public class SaveTest | |||
{ | |||
TrainDir = "mnist", | |||
OneHot = false, | |||
ValidationSize = 50000, | |||
ValidationSize = 0, | |||
}).Result; | |||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
model.save("", save_format:"pb"); | |||
model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); | |||
} | |||
} |