@@ -1,5 +1,5 @@ | |||||
namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
public record class CheckpointOptions( | public record class CheckpointOptions( | ||||
string experimental_io_device = null, | |||||
string? experimental_io_device = null, | |||||
bool experimental_enable_async_checkpoint = false); | bool experimental_enable_async_checkpoint = false); |
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Serilog.Debugging; | using Serilog.Debugging; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
@@ -21,9 +22,9 @@ public class ObjectGraphView: TrackableView, ICloneable | |||||
return new ObjectGraphView(Root, _attached_dependencies); | return new ObjectGraphView(Root, _attached_dependencies); | ||||
} | } | ||||
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null) | |||||
{ | { | ||||
List<TrackableReference> res = base.children(obj, save_type) | |||||
List<TrackableReference> res = base.children(obj, save_type, serialization_cache) | |||||
.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | ||||
// Check the reference, not value. | // Check the reference, not value. | ||||
if (obj == Root && _attached_dependencies is not null) | if (obj == Root && _attached_dependencies is not null) | ||||
@@ -34,9 +35,9 @@ public class ObjectGraphView: TrackableView, ICloneable | |||||
return res; | return res; | ||||
} | } | ||||
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null) | |||||
{ | { | ||||
return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); | |||||
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer); | |||||
} | } | ||||
public IEnumerable<TrackableReference>? AttachedDependencies | public IEnumerable<TrackableReference>? AttachedDependencies | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint | |||||
); | ); | ||||
public static class SaveUtil | public static class SaveUtil | ||||
{ | { | ||||
public static (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | ||||
{ | { | ||||
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | ||||
@@ -117,16 +117,16 @@ namespace Tensorflow.Checkpoint | |||||
/// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
/// <param name="cache"></param> | /// <param name="cache"></param> | ||||
/// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
private static IDictionary<Trackable, IDictionary<string, object>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | ||||
{ | { | ||||
Dictionary<Trackable, IDictionary<string, object>> serialized_tensors = new(); | |||||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
foreach(var td in tensor_trackables) | foreach(var td in tensor_trackables) | ||||
{ | { | ||||
// TODO: deal with cache. | // TODO: deal with cache. | ||||
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | ||||
var trackable = td.object_to_save; | var trackable = td.object_to_save; | ||||
IDictionary<string, object> tensor_dict; | |||||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||||
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | ||||
{ | { | ||||
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | ||||
@@ -147,12 +147,12 @@ namespace Tensorflow.Checkpoint | |||||
return serialized_tensors; | return serialized_tensors; | ||||
} | } | ||||
private static IDictionary<string, object> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
{ | { | ||||
var trackable = trackable_data.object_to_save; | var trackable = trackable_data.object_to_save; | ||||
// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | ||||
IDictionary<string, object> ret_tensor_dict; | |||||
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||||
if (call_with_mapped_captures) | if (call_with_mapped_captures) | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -162,8 +162,8 @@ namespace Tensorflow.Checkpoint | |||||
ret_tensor_dict = trackable.serialize_to_tensors(); | ret_tensor_dict = trackable.serialize_to_tensors(); | ||||
} | } | ||||
// TODO: revise the types and complete it | |||||
Dictionary<string, object> tensor_dict = new(); | |||||
// TODO: deal with the type `SaveSpce` (currently it will never be it). | |||||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
foreach(var pair in ret_tensor_dict) | foreach(var pair in ret_tensor_dict) | ||||
{ | { | ||||
var local_name = TrackableUtils.escape_local_name(pair.Key); | var local_name = TrackableUtils.escape_local_name(pair.Key); | ||||
@@ -172,9 +172,10 @@ namespace Tensorflow.Checkpoint | |||||
tensor_dict[checkpoint_key] = maybe_tensor; | tensor_dict[checkpoint_key] = maybe_tensor; | ||||
if(maybe_tensor is SaveSpec) | |||||
if(maybe_tensor.GetValueA() is SaveSpec) | |||||
{ | { | ||||
((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||||
throw new NotImplementedException(); | |||||
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||||
} | } | ||||
if(object_graph_proto is not null) | if(object_graph_proto is not null) | ||||
@@ -198,7 +199,7 @@ namespace Tensorflow.Checkpoint | |||||
/// <param name="call_with_mapped_captures"></param> | /// <param name="call_with_mapped_captures"></param> | ||||
/// <param name="object_graph_proto"></param> | /// <param name="object_graph_proto"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
private static (Trackable, IDictionary<string, object>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | ||||
{ | { | ||||
Dictionary<Trackable, string> object_names = new(); | Dictionary<Trackable, string> object_names = new(); | ||||
@@ -174,25 +174,20 @@ public static class SaveUtilV1 | |||||
{ | { | ||||
var name = factory_data.name; | var name = factory_data.name; | ||||
var key = factory_data.checkpoint_key; | var key = factory_data.checkpoint_key; | ||||
var saveable_factory = factory_data.factory; | |||||
var maybe_saveable = factory_data.factory; | |||||
// TODO: oneflow python has a process with callable `saveable_factory`. | // TODO: oneflow python has a process with callable `saveable_factory`. | ||||
var maybe_saveable = saveable_factory; | |||||
IEnumerable<MySaveableObject> savesbles; | |||||
if (maybe_saveable is MySaveableObject) | |||||
{ | |||||
savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable }; | |||||
} | |||||
else if (maybe_saveable is Tensor) | |||||
List<MySaveableObject> saveables = new(); | |||||
if (maybe_saveable.DataType == typeof(MySaveableObject)) | |||||
{ | { | ||||
savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); | |||||
saveables.Add(maybe_saveable.GetValueB()); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
throw new TypeError("Unexpected type."); | |||||
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); | |||||
} | } | ||||
foreach (var saveable in savesbles) | |||||
foreach (var saveable in saveables) | |||||
{ | { | ||||
if (!saveable.name.Contains(key)) | if (!saveable.name.Contains(key)) | ||||
{ | { | ||||
@@ -204,11 +199,11 @@ public static class SaveUtilV1 | |||||
// skip the process of PythonState | // skip the process of PythonState | ||||
named_saveable_objects.AddRange(savesbles); | |||||
named_saveable_objects.AddRange(saveables); | |||||
if(!fill_object_proto) continue; | if(!fill_object_proto) continue; | ||||
// skip the process of TrackableSaveable | |||||
// skip the process of `TrackableSaveable` because of lack of APIs. | |||||
object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() | object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() | ||||
{ Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); | { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); | ||||
@@ -221,7 +216,7 @@ public static class SaveUtilV1 | |||||
public record class CheckpointFactoryData | public record class CheckpointFactoryData | ||||
( | ( | ||||
object factory, | |||||
Maybe<ResourceVariable, MySaveableObject> factory, | |||||
string name, | string name, | ||||
string checkpoint_key | string checkpoint_key | ||||
); | ); |
@@ -2,6 +2,7 @@ | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
@@ -18,13 +19,13 @@ public class TrackableView | |||||
_root_ref = obj; | _root_ref = obj; | ||||
} | } | ||||
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
obj._maybe_initialize_trackable(); | obj._maybe_initialize_trackable(); | ||||
Dictionary<string, Trackable> children = new(); | Dictionary<string, Trackable> children = new(); | ||||
// Note: in python the return type of `Trackable._trackable_children` is not fixed. | // Note: in python the return type of `Trackable._trackable_children` is not fixed. | ||||
// Therefore it uses `convert_to_trackable` to have an extra process. | // Therefore it uses `convert_to_trackable` to have an extra process. | ||||
foreach (var pair in obj._trackable_children(save_type)) | |||||
foreach (var pair in obj._trackable_children(save_type, cache)) | |||||
{ | { | ||||
children[pair.Key] = pair.Value; | children[pair.Key] = pair.Value; | ||||
} | } | ||||
@@ -33,7 +33,7 @@ public class TrackableSaver | |||||
} | } | ||||
private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
gather_serialized_tensors(Tensor? object_graph_tensor = null) | gather_serialized_tensors(Tensor? object_graph_tensor = null) | ||||
{ | { | ||||
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | ||||
@@ -125,7 +125,7 @@ public class TrackableSaver | |||||
} | } | ||||
Dictionary<Tensor, string> feed_dict = new(); | Dictionary<Tensor, string> feed_dict = new(); | ||||
bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); | |||||
bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); | |||||
if (checkpoint_number is not null) | if (checkpoint_number is not null) | ||||
{ | { | ||||
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | ||||
@@ -133,6 +133,7 @@ public class TrackableSaver | |||||
Tensor file_prefix_tensor; | Tensor file_prefix_tensor; | ||||
Tensor object_graph_tensor; | Tensor object_graph_tensor; | ||||
string file_prefix_to_save; | |||||
if (use_session) | if (use_session) | ||||
{ | { | ||||
if (_object_graph_feed_tensor is null) | if (_object_graph_feed_tensor is null) | ||||
@@ -145,16 +146,18 @@ public class TrackableSaver | |||||
object_graph_tensor = _object_graph_feed_tensor; | object_graph_tensor = _object_graph_feed_tensor; | ||||
file_prefix_tensor = _file_prefix_feed_tensor; | file_prefix_tensor = _file_prefix_feed_tensor; | ||||
feed_dict[file_prefix_tensor] = file_prefix; | feed_dict[file_prefix_tensor] = file_prefix; | ||||
file_prefix_to_save = ""; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
// In python there is `with ops.device("/cpu:0")`. | // In python there is `with ops.device("/cpu:0")`. | ||||
file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); | file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); | ||||
object_graph_tensor = null; | object_graph_tensor = null; | ||||
file_prefix_to_save = file_prefix; | |||||
} | } | ||||
var (save_path, new_feed_additions) = | var (save_path, new_feed_additions) = | ||||
save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); | |||||
save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options); | |||||
if (new_feed_additions is not null) | if (new_feed_additions is not null) | ||||
{ | { | ||||
@@ -6,9 +6,254 @@ using Tensorflow.Train; | |||||
using static Tensorflow.ApiDef.Types; | using static Tensorflow.ApiDef.Types; | ||||
using static Tensorflow.CostGraphDef.Types; | using static Tensorflow.CostGraphDef.Types; | ||||
using static Tensorflow.OptimizerOptions.Types; | using static Tensorflow.OptimizerOptions.Types; | ||||
using static Tensorflow.Binding; | |||||
using System.Text.RegularExpressions; | |||||
using System.Linq; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Training; | |||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
{ | { | ||||
/// <summary>
/// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions
/// through one non-generic interface, regardless of the wrapped delegate's signature.
/// Note that this API does not guarantee performance (every call goes through
/// `Delegate.DynamicInvoke`). Besides, it is not supposed to be exposed to users.
/// </summary>
public interface IFunctionHolder
{
    /// <summary>Number of arguments the wrapped delegate takes.</summary>
    int ArgCount { get; }

    /// <summary>
    /// Invokes the wrapped delegate with the given arguments.
    /// </summary>
    /// <param name="args">Positional arguments; the count and types must match the wrapped delegate.</param>
    /// <returns>
    /// Whatever the wrapped delegate returns. This may be <c>null</c> (holders wrapping
    /// <c>() => null</c> are created in this file), hence the nullable return type.
    /// </returns>
    object? DynamicInvoke(params object[] args);
}

/// <summary>Holder for a delegate that takes no argument.</summary>
internal record class FunctionHolder<TR>(Func<TR> Func) : IFunctionHolder
{
    public int ArgCount => 0;

    public object? DynamicInvoke(params object[] args) => Func.DynamicInvoke(args);
}

/// <summary>Holder for a delegate that takes one argument.</summary>
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder
{
    public int ArgCount => 1;

    public object? DynamicInvoke(params object[] args) => Func.DynamicInvoke(args);
}

/// <summary>Holder for a delegate that takes two arguments.</summary>
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder
{
    public int ArgCount => 2;

    public object? DynamicInvoke(params object[] args) => Func.DynamicInvoke(args);
}

/// <summary>Holder for a delegate that takes three arguments.</summary>
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder
{
    public int ArgCount => 3;

    public object? DynamicInvoke(params object[] args) => Func.DynamicInvoke(args);
}
/// <summary>
/// A minimal two-case union: an instance holds either a <typeparamref name="TA"/> or a
/// <typeparamref name="TB"/>, depending on which constructor (or implicit conversion) created it.
/// </summary>
/// <remarks>
/// The active case is recorded as <c>typeof(TA)</c> or <c>typeof(TB)</c>, i.e. the static type
/// argument and not the runtime type of the stored value. The two cases can therefore only be
/// told apart when <typeparamref name="TA"/> and <typeparamref name="TB"/> are different types.
/// NOTE(review): C# does not apply user-defined conversions from interface types, so the implicit
/// operators presumably do not kick in when the source type argument is an interface - confirm
/// at the call sites and use the constructor explicitly there.
/// </remarks>
public class Maybe<TA, TB>
{
    // Exactly one of the two value fields is populated; `_type` says which.
    private readonly TA? _valueA;
    private readonly TB? _valueB;
    private readonly Type _type;

    /// <summary>Creates an instance holding the <typeparamref name="TA"/> case.</summary>
    public Maybe(TA value)
    {
        _valueA = value;
        _type = typeof(TA);
    }

    /// <summary>Creates an instance holding the <typeparamref name="TB"/> case.</summary>
    public Maybe(TB value)
    {
        _valueB = value;
        _type = typeof(TB);
    }

    /// <summary>
    /// <c>typeof(TA)</c> or <c>typeof(TB)</c>, depending on which case this instance holds.
    /// </summary>
    public Type DataType => _type;

    /// <summary>Returns the stored value of the <typeparamref name="TA"/> case.</summary>
    /// <exception cref="TypeError">This instance holds the <typeparamref name="TB"/> case.</exception>
    public TA GetValueA()
    {
        if (DataType != typeof(TA))
        {
            throw new TypeError("Cannot get the data because of wrong specified type.");
        }
        return _valueA!;
    }

    /// <summary>Returns the stored value of the <typeparamref name="TB"/> case.</summary>
    /// <exception cref="TypeError">This instance holds the <typeparamref name="TA"/> case.</exception>
    public TB GetValueB()
    {
        if (DataType != typeof(TB))
        {
            throw new TypeError("Cannot get the data because of wrong specified type.");
        }
        return _valueB!;
    }

    /// <summary>
    /// Returns the stored value of whichever case is active, boxed as <see cref="object"/>.
    /// The result is null when a null value was stored.
    /// </summary>
    public object GetValue()
    {
        if (DataType == typeof(TA))
        {
            return _valueA!;
        }
        return _valueB!;
    }

    public static implicit operator Maybe<TA, TB>(TA a)
    {
        return new Maybe<TA, TB>(a);
    }

    public static implicit operator Maybe<TA, TB>(TB b)
    {
        return new Maybe<TA, TB>(b);
    }
}
/// <summary>
/// Saves and restores one group of tensors through `tf.io.save_v2` / `tf.io.restore_v2`.
/// The tensors are organised as checkpoint_key -> (slice_spec -> Tensor or SaveSpec).
/// </summary>
internal class SingleDeviceSaver
{
    // checkpoint_key -> (slice_spec -> Tensor | SaveSpec)
    private readonly IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;

    /// <summary>Uses <paramref name="tensor_slice_dict"/> as is (no copy is made).</summary>
    public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
    {
        _tensor_slice_dict = tensor_slice_dict;
    }

    /// <summary>Wraps every tensor of <paramref name="tensor_slice_dict"/> into the tensor case of `Maybe`.</summary>
    public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict)
    {
        // The cast makes the outer dictionary's value type the interface rather than `Dictionary`.
        _tensor_slice_dict = tensor_slice_dict.ToDictionary(
            x => x.Key,
            x => (IDictionary<string, Maybe<Tensor, SaveSpec>>)x.Value.ToDictionary(
                y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)));
    }

    /// <summary>Wraps every spec of <paramref name="tensor_slice_dict"/> into the `SaveSpec` case of `Maybe`.</summary>
    public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
    {
        _tensor_slice_dict = tensor_slice_dict.ToDictionary(
            x => x.Key,
            x => (IDictionary<string, Maybe<Tensor, SaveSpec>>)x.Value.ToDictionary(
                y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)));
    }

    /// <summary>
    /// Writes all held tensors to a checkpoint with a single `save_v2` call.
    /// </summary>
    /// <param name="file_prefix">String tensor holding the checkpoint prefix.</param>
    /// <param name="options">Checkpoint options; defaults are used when null. Currently not consulted.</param>
    /// <returns>The result of `tf.io.save_v2`.</returns>
    public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
    {
        options ??= new CheckpointOptions();

        List<string> tensor_names = new();
        List<Tensor> tensors = new();
        List<string> slice_specs = new();
        foreach (var pair in _tensor_slice_dict)
        {
            var checkpoint_key = pair.Key;
            foreach (var slice in pair.Value)
            {
                var maybe_tensor = slice.Value;
                // TODO: deal with other types. Currently only `SaveSpec` is allowed.
                if (maybe_tensor.DataType == typeof(SaveSpec))
                {
                    var spec = maybe_tensor.GetValueB();
                    var tensor_value = spec.tensor;
                    // A spec without a tensor contributes nothing to the checkpoint.
                    if (tensor_value is not null)
                    {
                        tensor_names.Add(spec.name);
                        tensors.Add(tensor_value);
                        slice_specs.Add(spec.slice_spec);
                    }
                }
                else
                {
                    tensor_names.Add(checkpoint_key);
                    tensors.Add(maybe_tensor.GetValueA());
                    slice_specs.Add(slice.Key);
                }
            }
        }
        // TODO: specify the device.
        return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray());
    }

    /// <summary>Convenience overload that converts <paramref name="file_prefix"/> to a string tensor first.</summary>
    public Operation? save(string file_prefix, CheckpointOptions? options = null)
        => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options);

    /// <summary>
    /// Reads all held entries back from a checkpoint with a single `restore_v2` call.
    /// </summary>
    /// <param name="file_prefix">String tensor holding the checkpoint prefix.</param>
    /// <param name="options">
    /// Checkpoint options; `experimental_io_device` selects the restore device ("cpu:0" when unset).
    /// </param>
    /// <returns>checkpoint_key -> (slice_spec -> restored tensor), keyed like the held dictionary.</returns>
    public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null)
    {
        options ??= new CheckpointOptions();

        List<string> tensor_names = new();
        List<TF_DataType> tensor_dtypes = new();
        List<string> slice_specs = new();
        foreach (var pair in _tensor_slice_dict)
        {
            var checkpoint_key = pair.Key;
            foreach (var slice in pair.Value)
            {
                var maybe_tensor = slice.Value;
                // TODO: deal with other types. Currently only `SaveSpec` is allowed.
                if (maybe_tensor.DataType == typeof(SaveSpec))
                {
                    var spec = maybe_tensor.GetValueB();
                    tensor_dtypes.Add(spec.dtype);
                    slice_specs.Add(spec.slice_spec);
                    tensor_names.Add(spec.name);
                }
                else
                {
                    tensor_dtypes.Add(maybe_tensor.GetValueA().dtype);
                    slice_specs.Add(slice.Key);
                    tensor_names.Add(checkpoint_key);
                }
            }
        }

        string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0" : options.experimental_io_device!;

        // tf python has code `with ops.device(restore_device):` here.
        tf.device(restore_device); // may be risky.
        var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());

        // The results come back in request order, so walk the dictionary in the same order again.
        Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
        int idx = 0;
        foreach (var pair in _tensor_slice_dict)
        {
            var checkpoint_key = pair.Key;
            foreach (var slice_spec in pair.Value.Keys)
            {
                var restored_tensor = restored_tensors[idx++];
                if (!restored_tensor_dict.TryGetValue(checkpoint_key, out var restored_slices))
                {
                    restored_slices = new Dictionary<string, Tensor>();
                    restored_tensor_dict[checkpoint_key] = restored_slices;
                }
                restored_slices[slice_spec] = restored_tensor;
            }
        }
        return restored_tensor_dict;
    }

    /// <summary>
    /// Convenience overload that converts <paramref name="file_prefix"/> to a tensor first.
    /// <paramref name="options"/> is forwarded so that `experimental_io_device` is honoured.
    /// </summary>
    public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null)
        => restore(tf.constant(file_prefix), options);
}
/// <summary> | /// <summary> | ||||
/// Saves checkpoints directly from multiple devices. | /// Saves checkpoints directly from multiple devices. | ||||
/// Note that this is a low-level utility which stores Tensors in the keys | /// Note that this is a low-level utility which stores Tensors in the keys | ||||
@@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint | |||||
/// </summary> | /// </summary> | ||||
public class MultiDeviceSaver | public class MultiDeviceSaver | ||||
{ | { | ||||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, object>> serialized_tensors, | |||||
private Dictionary<string, SingleDeviceSaver> _single_device_savers; | |||||
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers; | |||||
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; | |||||
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys; | |||||
/// <summary>
/// Groups the serialized tensors by the device they live on and creates one
/// `SingleDeviceSaver` per device. It also records, for every (checkpoint_key, slice_spec)
/// pair, which restore function is responsible for it.
/// </summary>
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers">Registered savers; only null or empty is supported for now.</param>
/// <param name="call_with_mapped_capture">Currently unused.</param>
/// <exception cref="ValueError">Two tensors share the same checkpoint key and slice spec.</exception>
/// <exception cref="NotImplementedException"><paramref name="registered_savers"/> is not empty.</exception>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
    IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{
    _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>();
    _restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>();
    // device -> checkpoint_key -> slice_spec -> tensor
    Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device = new();

    foreach (var pair in serialized_tensors)
    {
        var obj = pair.Key;
        var tensor_dict = pair.Value;

        IFunctionHolder restore_fn;
        if (obj is null)
        {
            restore_fn = new FunctionHolder<object?>(() => null);
        }
        else
        {
            // TODO: implement obj._restore_from_tensors
            // Until then a placeholder is stored instead of null: `restore_fn` is used below as a
            // key of `_restore_fn_to_keys`, and `Dictionary` throws on a null key. The lambda
            // captures `obj`, so each trackable gets its own delegate and thus its own entry
            // (the holder is a record, i.e. it is compared by its delegate).
            restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, object?>(
                _ => throw new NotImplementedException(
                    $"Restoring tensors into `{obj.GetType().Name}` is not implemented yet."));
        }

        foreach (var item in tensor_dict)
        {
            var checkpoint_key = item.Key;

            // Normalize a bare tensor to a one-entry dict with the empty slice spec.
            IDictionary<string, Tensor> spec_to_tensor;
            if (item.Value.DataType != typeof(IDictionary<string, Tensor>))
            {
                spec_to_tensor = new Dictionary<string, Tensor>();
                spec_to_tensor[""] = item.Value.GetValueA();
            }
            else
            {
                spec_to_tensor = item.Value.GetValueB();
            }

            foreach (var spec in spec_to_tensor)
            {
                var slice_spec = spec.Key;
                var tensor = spec.Value;
                if (_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec)))
                {
                    throw new ValueError("Recieved multiple tensors with the same checkpoint key and " +
                        $"slice spec. This is invalid because one will overwrite the " +
                        $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation.");
                }
                _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
                _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

                // skip the process of device name because lack of API.
                var host_device = tensor.Device;

                var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
                if (!internal_dict.TryGetValue(checkpoint_key, out var slices))
                {
                    slices = new Dictionary<string, Tensor>();
                    internal_dict[checkpoint_key] = slices;
                }
                slices[slice_spec] = tensor;
            }
        }
    }

    _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value));

    _registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>();
    if (registered_savers is not null && registered_savers.Count > 0)
    {
        // TODO: complete the implementation.
        throw new NotImplementedException();
    }
}
public Operation? save(string file_prefix, CheckpointOptions? options= null) | |||||
public Operation save(string file_prefix, CheckpointOptions? options= null) | |||||
{ | { | ||||
throw new NotImplementedException(); | |||||
if(options is null) | |||||
{ | |||||
options = new CheckpointOptions(); | |||||
} | |||||
tf.device("CPU"); // may be risky. | |||||
// TODO: optimize the implementation with new APIs adding to `string_ops`. | |||||
string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; | |||||
var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); | |||||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||||
Operation save_fn() | |||||
{ | |||||
List<Tensor> saved_prefixes= new(); | |||||
foreach(var saver in _registered_savers) | |||||
{ | |||||
// TODO: implementi it later. | |||||
throw new NotImplementedException(); | |||||
} | |||||
int num_shards = _single_device_savers.Count; | |||||
List<Operation> sharded_saves = new(); | |||||
var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); | |||||
string? last_device = null; | |||||
int shard = 0; | |||||
foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) | |||||
{ | |||||
var device = pair.Key; | |||||
var saver = pair.Value; | |||||
last_device = device; | |||||
// skip the extra process of device name because of lack of API. | |||||
tf.device(device); | |||||
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||||
saved_prefixes.Add(shard_prefix); | |||||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||||
} | |||||
using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||||
{ | |||||
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||||
tf.device(merge_device); | |||||
return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); | |||||
} | |||||
} | |||||
if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) | |||||
{ | |||||
// TODO: implement it. Currently `autograph` does not support functions that take no parameters.
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
return save_fn(); | |||||
} | |||||
} | } | ||||
/// <summary>
/// Saves a checkpoint using a file prefix supplied as a string tensor.
/// </summary>
/// <param name="file_prefix">Tensor whose first string element is used as the checkpoint prefix.</param>
/// <param name="options">Optional checkpoint options, forwarded unchanged to the string-based overload.</param>
/// <returns>Whatever the string-based <c>save</c> overload returns.</returns>
public Operation save(Tensor file_prefix, CheckpointOptions? options = null)
{
    // Only the first string element of the tensor is used as the prefix.
    var prefix = file_prefix.numpy().StringData()[0];
    return save(prefix, options);
}
/// <summary>
/// Restores the saved tensors from a checkpoint and passes them to the restore functions.
/// </summary>
/// <remarks>
/// Tensors read by each single-device saver are grouped per restore function. A restore function is
/// invoked as soon as the number of inputs it is still waiting for drops to zero.
/// </remarks>
/// <param name="file_prefix">Prefix of the checkpoint files to read.</param>
/// <param name="options">Optional checkpoint options; a default instance is used when null.</param>
/// <returns>
/// The operations returned by the restore functions, keyed by name. If several restore functions
/// return the same key, the value returned last is kept.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown when any registered saver exists, or when executing eagerly with more than one single-device saver.
/// </exception>
public IDictionary<string, Operation> restore(string file_prefix, CheckpointOptions? options = null)
{
    options ??= new CheckpointOptions();

    IDictionary<string, Operation> restore_func()
    {
        Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
        // Number of inputs each restore function is still waiting for.
        Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
        Dictionary<string, Operation> restore_ops = new();

        foreach (var single_saver in _single_device_savers.OrderBy(x => x.Key))
        {
            var device = single_saver.Key;
            var saver = single_saver.Value;
            tf.device(device);
            var restored_tensor_dict = saver.restore(file_prefix, options);
            foreach (var pair in restored_tensor_dict)
            {
                var checkpoint_key = pair.Key;
                var slice_and_tensor = pair.Value;
                foreach (var item in slice_and_tensor)
                {
                    var slice_spec = item.Key;
                    var tensor = item.Value;
                    var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
                    var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
                    if (string.IsNullOrEmpty(slice_spec))
                    {
                        // Unsliced tensor: stored directly under the checkpoint key.
                        internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
                    }
                    else if (internal_dict.ContainsKey(checkpoint_key))
                    {
                        // Another slice of a key that already has a slice dictionary.
                        internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor;
                    }
                    else
                    {
                        // First slice seen for this key: start its slice dictionary.
                        Dictionary<string, Tensor> slices = new();
                        slices[slice_spec] = tensor;
                        internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(slices);
                    }

                    restore_fn_input_count[restore_fn]--;
                    if (restore_fn_input_count[restore_fn] == 0)
                    {
                        // All inputs have arrived: re-key them by local name and run the restore function.
                        Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
                        foreach (var input in restore_fn_inputs[restore_fn])
                        {
                            restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
                        }
                        var ret = restore_fn.DynamicInvoke(restored_tensors);
                        if (ret is IDictionary<string, Operation> returned_ops)
                        {
                            // Update in place. Rebuilding with Concat(...).ToDictionary(...) would throw
                            // on a duplicate key and copy the whole dictionary on every merge.
                            foreach (var returned in returned_ops)
                            {
                                restore_ops[returned.Key] = returned.Value;
                            }
                        }
                    }
                }
            }
        }

        foreach (var item in _registered_savers)
        {
            // Registered savers are not supported yet.
            throw new NotImplementedException();
        }
        return restore_ops;
    }

    // TODO: complete the implementation. Currently skip it because of lack of API.
    bool has_custom_device_saver = false;
    if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver))
    {
        // TODO: implement it. Currently `autograph` does not support functions that take no parameters.
        throw new NotImplementedException();
    }
    else
    {
        return restore_func();
    }
}
/// <summary>
/// Serializes this saver to a <c>SaverDef</c> that refers to the traced save and restore by name.
/// </summary>
/// <returns>
/// A <c>SaverDef</c> in checkpoint format V2 holding the names of the filename placeholder,
/// the traced save tensor and the traced restore op.
/// </returns>
public SaverDef to_proto()
{
    // Scalar string placeholder that stands in for the checkpoint file name.
    var filename_placeholder = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename");
    var traced_save_output = _traced_save(filename_placeholder);
    var traced_restore_op = _traced_restore(filename_placeholder).op;

    var saver_def = new SaverDef();
    saver_def.FilenameTensorName = filename_placeholder.name;
    saver_def.SaveTensorName = traced_save_output.name;
    saver_def.RestoreOpName = traced_restore_op.name;
    saver_def.Version = SaverDef.Types.CheckpointFormatVersion.V2;
    return saver_def;
}
/// <summary>
/// Runs <c>save</c> with the first string element of <paramref name="file_prefix"/> and returns an
/// identity of the prefix that is created under a control dependency on the save operation.
/// </summary>
/// <param name="file_prefix">Tensor holding the checkpoint prefix.</param>
/// <returns>The identity tensor of <paramref name="file_prefix"/>.</returns>
[AutoGraph]
private Tensor _traced_save(Tensor file_prefix)
{
    var save_operation = save(file_prefix.StringData()[0]);
    tf.device("cpu:0");
    // The scope is disposed when the method returns, i.e. after the identity op has been created.
    using var dependency_scope = ops.control_dependencies(new object[] { save_operation });
    return array_ops.identity(file_prefix);
}
/// <summary>
/// Runs <c>restore</c> with the first string element of <paramref name="file_prefix"/> and returns an
/// identity of the prefix that is created under control dependencies on every restore operation.
/// </summary>
/// <param name="file_prefix">Tensor holding the checkpoint prefix.</param>
/// <returns>The identity tensor of <paramref name="file_prefix"/>.</returns>
[AutoGraph]
private Tensor _traced_restore(Tensor file_prefix)
{
    var restore_ops = restore(file_prefix.StringData()[0]);
    tf.device("cpu:0");
    // `restore` returns a name -> operation dictionary. The control inputs must be the operations
    // themselves; passing the dictionary as a single control input is not a valid graph element.
    object[] control_inputs = restore_ops.Values.Cast<object>().ToArray();
    using (ops.control_dependencies(control_inputs))
    {
        return array_ops.identity(file_prefix);
    }
}
/// <summary>
/// Builds the constant string tensor "&lt;filename&gt;-&lt;saver_name&gt;" used for a registered saver.
/// </summary>
/// <param name="filename">Checkpoint file prefix.</param>
/// <param name="saver_name">Name of the registered saver.</param>
/// <returns>A constant tensor holding the joined name.</returns>
private static Tensor registered_saver_filename(string filename, string saver_name)
{
    string registered_name = string.Concat(filename, "-", saver_name);
    return tf.constant(registered_name);
}
/// <summary>
/// Returns the file name tensor to use for one shard.
/// </summary>
/// <remarks>
/// The current implementation returns <paramref name="filename_tensor"/> unchanged;
/// <paramref name="shard"/> and <paramref name="num_shards"/> are not used yet.
/// </remarks>
/// <param name="filename_tensor">Base file name tensor.</param>
/// <param name="shard">Index of the shard (currently unused).</param>
/// <param name="num_shards">Total number of shards (currently unused).</param>
/// <returns><paramref name="filename_tensor"/> itself.</returns>
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards)
    => filename_tensor;
} | } | ||||
} | } |
@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | |||||
{ | |||||
/// <summary>
/// Contract for a holder of the functions and checkpointable objects of a serialized object.
/// </summary>
public interface ISerializedAttributes
{
    /// <summary>
    /// Functions held by this instance, keyed by name.
    /// NOTE(review): how this differs from <see cref="FunctionsToSerialize"/> is defined by implementers - confirm.
    /// </summary>
    IDictionary<string, Trackable> Functions { get; }

    /// <summary>
    /// Checkpointable objects held by this instance, keyed by name.
    /// NOTE(review): how this differs from <see cref="ObjectsToSerialize"/> is defined by implementers - confirm.
    /// </summary>
    IDictionary<string, Trackable> CheckpointableObjects { get; }

    /// <summary>
    /// Returns functions to attach to the root object during serialization.
    /// </summary>
    IDictionary<string, Trackable> FunctionsToSerialize { get; }

    /// <summary>
    /// Returns objects to attach to the root object during serialization.
    /// </summary>
    IDictionary<string, Trackable> ObjectsToSerialize { get; }

    /// <summary>
    /// Saves function dictionary, and validates dictionary values.
    /// </summary>
    /// <param name="function_dict">Functions to store, keyed by name.</param>
    IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict);

    /// <summary>
    /// Saves objects to a dictionary, and validates the values.
    /// </summary>
    /// <param name="object_dict">Objects to store, keyed by name.</param>
    IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict);
}
} |
@@ -1,6 +1,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -24,7 +25,7 @@ namespace Tensorflow.Train | |||||
} | } | ||||
} | } | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
if(save_type != SaveType.SAVEDMODEL) | if(save_type != SaveType.SAVEDMODEL) | ||||
{ | { | ||||
@@ -4,57 +4,130 @@ using Tensorflow.Train; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
namespace Tensorflow; | namespace Tensorflow; | ||||
/// <summary>
/// An <c>ObjectGraphView</c> that caches the children of every visited object and lets callers
/// attach a signature map and wrapped functions to the root.
/// </summary>
public class AugmentedGraphView: ObjectGraphView
{
    // Children of each object that has been listed, keyed by child name.
    private readonly Dictionary<Trackable, IDictionary<string, Trackable>> _children_cache;
    // Cache handed to the base class when listing children.
    private readonly Dictionary<string, IDictionary<Trackable, ISerializedAttributes>> _serialization_cache;
    // Names of `Function` objects that had no children when they were listed.
    private readonly List<string> _untraces_functions;
    // Replacement functions consulted by `maybe_uncache_variable_captures`.
    private readonly Dictionary<ConcreteFunction, ConcreteFunction> _wrapped_functions;

    public AugmentedGraphView(Trackable root): base(root)
    {
        _children_cache = new Dictionary<Trackable, IDictionary<string, Trackable>>();
        _serialization_cache = new Dictionary<string, IDictionary<Trackable, ISerializedAttributes>>();
        _untraces_functions = new List<string>();
        _wrapped_functions = new Dictionary<ConcreteFunction, ConcreteFunction>();
    }

    /// <summary>
    /// Attaches the signature map as a child of the root and records the wrapped functions.
    /// </summary>
    /// <param name="signature_map">Stored under the signature attribute name in the root's children.</param>
    /// <param name="wrapped_functions">Merged into the known wrapped functions; existing keys are overwritten.</param>
    public void set_signature(SignatureMap signature_map, IDictionary<ConcreteFunction, ConcreteFunction> wrapped_functions)
    {
        // Listing the root first makes sure its children are cached before the signature is added.
        list_children(Root);
        var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME;
        if (!_children_cache.TryGetValue(Root, out var root_children))
        {
            root_children = new Dictionary<string, Trackable>();
            _children_cache[Root] = root_children;
        }
        root_children[name] = signature_map;

        // Update in place: Concat(...).ToDictionary(...) would throw on a key that is already present.
        foreach (var pair in wrapped_functions)
        {
            _wrapped_functions[pair.Key] = pair.Value;
        }
    }

    /// <summary>
    /// Lists the children of <paramref name="obj"/>, computing them once and serving later calls from the cache.
    /// </summary>
    /// <param name="obj">Object whose children are requested.</param>
    /// <param name="save_type">Not used; children are always collected with <c>SaveType.SAVEDMODEL</c>.</param>
    /// <param name="serialization_cache">Must be null; this view uses its own serialization cache.</param>
    /// <exception cref="ValueError">Thrown when <paramref name="serialization_cache"/> is not null.</exception>
    public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        if (serialization_cache is not null)
        {
            throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`.");
        }

        if (!_children_cache.TryGetValue(obj, out var cached))
        {
            var children = new Dictionary<string, Trackable>();
            // Register the (still empty) entry before asking the base class, as the original code does.
            _children_cache[obj] = children;
            foreach (var reference in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache))
            {
                var child = reference.Refer;
                if (child is ConcreteFunction concrete)
                {
                    child = maybe_uncache_variable_captures(concrete);
                }
                children[reference.Name] = child;
            }
            if (obj is Function function && children.Count == 0)
            {
                _untraces_functions.Add(function.Name);
            }
            cached = children;
        }

        return cached.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
    }

    /// <summary>
    /// Returns the wrapped replacement of <paramref name="concrete_function"/> when one is registered,
    /// otherwise the function itself.
    /// </summary>
    private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function)
    {
        if (_wrapped_functions.TryGetValue(concrete_function, out var wrapped))
        {
            return wrapped;
        }
        // skip the process here because of lack of feature.
        // In the future, we may add an attribute which could specify if the variable is supposed to be cached.
        return concrete_function;
    }

    /// <summary>
    /// Traverses the graph, rewrites every cached child through the merge hook (currently the identity),
    /// then traverses again and returns that result.
    /// </summary>
    public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
    {
        Trackable get_merged_trackable(Trackable x)
        {
            // TODO: complete it with new definitions `Asset` and `TrackableConstant`.
            return x;
        }

        // First traversal; its result is not needed.
        // NOTE(review): presumably this is what fills `_children_cache` through `list_children` - confirm.
        base.breadth_first_traversal();

        foreach (var children in _children_cache.Values)
        {
            // skip the deletion of cache (maybe do it later).
            // Iterate over a snapshot: writing to a dictionary while enumerating it can throw
            // InvalidOperationException on some runtimes and dictionary implementations.
            foreach (var pair in children.ToList())
            {
                children[pair.Key] = get_merged_trackable(pair.Value);
            }
        }
        return base.breadth_first_traversal();
    }

    /// <summary>
    /// Lists the deserialization dependencies of <paramref name="obj"/>, passing its cached children
    /// (or an empty dictionary when nothing is cached) to the object.
    /// </summary>
    public List<(string, Trackable)> list_dependencies(Trackable obj)
    {
        if (!_children_cache.TryGetValue(obj, out var children))
        {
            children = new Dictionary<string, Trackable>();
        }
        return obj.deserialization_dependencies(children).Select(pair => (pair.Key, pair.Value)).ToList();
    }

    /// <summary>
    /// Returns the cached child of <paramref name="obj"/> with the given name.
    /// The lookup throws when the object or the name is not in the cache.
    /// </summary>
    public Trackable get_child(Trackable obj, string name)
    {
        return _children_cache[obj][name];
    }
}
@@ -141,16 +141,16 @@ public class SaveableView | |||||
foreach (var node in _nodes) | foreach (var node in _nodes) | ||||
{ | { | ||||
var node_id = _node_ids[node]; | var node_id = _node_ids[node]; | ||||
List<int> deps = new(); | |||||
List<int> deps = new List<int>(); | |||||
dependency_map.Add(node_id, deps); | |||||
// TODO: deal with captured tensor. | // TODO: deal with captured tensor. | ||||
string node_path; | |||||
foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) | foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) | ||||
{ | { | ||||
if (!_node_ids.ContainsKey(dep)) | if (!_node_ids.ContainsKey(dep)) | ||||
{ | { | ||||
node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); | |||||
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); | |||||
throw new ValueError( | throw new ValueError( | ||||
$"Found an untracked dependency. Object {node_path} depends on {dep}, " + | $"Found an untracked dependency. Object {node_path} depends on {dep}, " + | ||||
$"but this dependency isn't listed as a child. Please track this child by " + | $"but this dependency isn't listed as a child. Please track this child by " + | ||||
@@ -24,7 +24,7 @@ public static partial class SavedModelUtils | |||||
}.Select(x => (int)x); | }.Select(x => (int)x); | ||||
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj, | public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj, | ||||
string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) | |||||
string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) | |||||
{ | { | ||||
if (options is null) | if (options is null) | ||||
{ | { | ||||
@@ -41,9 +41,9 @@ public static partial class SavedModelUtils | |||||
if (!experimental_skip_checkpoint) | if (!experimental_skip_checkpoint) | ||||
{ | { | ||||
Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); | |||||
SavedModelUtils.get_or_create_variables_dir(export_dir); | |||||
CheckpointOptions ckpt_options = new(options.experimental_io_device); | CheckpointOptions ckpt_options = new(options.experimental_io_device); | ||||
object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||||
object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||||
} | } | ||||
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); | BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); | ||||
@@ -67,7 +67,7 @@ public static partial class SavedModelUtils | |||||
} | } | ||||
var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); | var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); | ||||
File.WriteAllText(path, saved_model.ToString()); | |||||
File.WriteAllBytes(path, saved_model.ToByteArray()); | |||||
if (options.save_debug_info) | if (options.save_debug_info) | ||||
{ | { | ||||
@@ -81,7 +81,7 @@ public static partial class SavedModelUtils | |||||
private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | ||||
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | ||||
IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||||
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||||
{ | { | ||||
if (ops.inside_function()) | if (ops.inside_function()) | ||||
{ | { | ||||
@@ -95,9 +95,9 @@ public static partial class SavedModelUtils | |||||
} | } | ||||
AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | ||||
if (signatures is not null) | |||||
if (signatures is null) | |||||
{ | { | ||||
throw new NotImplementedException(); | |||||
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); | |||||
} | } | ||||
// TODO: process the signatures and wrapped_functions.
@@ -125,7 +125,7 @@ public static partial class SavedModelUtils | |||||
} | } | ||||
private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, | private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, | ||||
IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist, | |||||
ConcreteFunction signatures, IEnumerable<string> namespace_whitelist, | |||||
bool save_custom_gradients) | bool save_custom_gradients) | ||||
{ | { | ||||
var resource_initializers = saveable_view.get_concrete_resource_initializers(); | var resource_initializers = saveable_view.get_concrete_resource_initializers(); | ||||
@@ -1,15 +1,84 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
namespace Tensorflow; | namespace Tensorflow; | ||||
/// <summary>
/// Helpers for building signature maps and for picking the function to export as a signature.
/// </summary>
public static class SignatureSerializationUtils
{
    internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature";
    internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures";
    internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5;

    /// <summary>
    /// Creates a <c>SignatureMap</c> holding every entry of <paramref name="signatures"/>.
    /// Each value is expected to be a <c>ConcreteFunction</c> and is cast to one.
    /// </summary>
    public static SignatureMap create_signature_map(IDictionary<string, Trackable> signatures)
    {
        var result = new SignatureMap();
        foreach (var entry in signatures)
        {
            Debug.Assert(entry.Value is ConcreteFunction);
            // TODO: assert the `func.structured_outputs` and arg_keywords.
            result._add_signature(entry.Key, (ConcreteFunction)entry.Value);
        }
        return result;
    }

    /// <summary>
    /// Looks through the children of the view's root for the function to export.
    /// </summary>
    /// <returns>
    /// The child stored under the default signature attribute if there is one; otherwise the single
    /// valid signature when exactly one candidate exists; otherwise null.
    /// </returns>
    public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view)
    {
        List<Trackable> candidates = new();
        foreach (var reference in graph_view.list_children(graph_view.Root))
        {
            var child_name = reference.Name;
            var child = reference.Refer;
            bool is_function = child is Function || child is ConcreteFunction;
            if (!is_function)
            {
                continue;
            }
            if (child_name == DEFAULT_SIGNATURE_ATTR)
            {
                Debug.Assert(child is ConcreteFunction);
                return (ConcreteFunction)child;
            }
            ConcreteFunction candidate = get_signature(child);
            if (candidate is null || !valid_signature(candidate))
            {
                continue;
            }
            candidates.Add(candidate);
        }

        if (candidates.Count != 1)
        {
            return null;
        }
        var only_signature = get_signature(candidates[0]);
        return only_signature is not null && valid_signature(only_signature) ? only_signature : null;
    }

    /// <summary>Placeholder: always returns null for now.</summary>
    private static ConcreteFunction get_signature(Trackable function)
    {
        // TODO: implement it.
        return null;
    }

    /// <summary>Placeholder: always returns false for now.</summary>
    private static bool valid_signature(ConcreteFunction concreate_function)
    {
        // TODO: implement it.
        return false;
    }
}
public class SignatureMap: Trackable | public class SignatureMap: Trackable | ||||
{ | { | ||||
private Dictionary<string, Function> _signatures; | |||||
private Dictionary<string, ConcreteFunction> _concrete_signatures; | |||||
private Dictionary<string, Trackable> _signatures; | |||||
public SignatureMap() | public SignatureMap() | ||||
{ | { | ||||
@@ -18,7 +87,7 @@ public class SignatureMap: Trackable | |||||
/// <summary>
/// Stores <paramref name="concrete_function"/> under <paramref name="name"/>, replacing any existing entry.
/// </summary>
public void _add_signature(string name, ConcreteFunction concrete_function)
    => _signatures[name] = concrete_function;
public void _add_signature(string name, Function concrete_function) | public void _add_signature(string name, Function concrete_function) | ||||
@@ -26,33 +95,13 @@ public class SignatureMap: Trackable | |||||
_signatures[name] = concrete_function; | _signatures[name] = concrete_function; | ||||
} | } | ||||
/// <summary>
/// Returns the signatures that are functions, keyed by signature name.
/// </summary>
/// <param name="save_type">Only <c>SaveType.SAVEDMODEL</c> yields children; any other value yields an empty dictionary.</param>
/// <param name="cache">Not used by this override.</param>
/// <returns>A new dictionary with every entry whose value is a <c>Function</c> or <c>ConcreteFunction</c>.</returns>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
    if (save_type != SaveType.SAVEDMODEL)
    {
        return new Dictionary<string, Trackable>();
    }

    // Filter with `Where`: `TakeWhile` would stop at the first non-function entry and drop
    // every signature that comes after it.
    return _signatures
        .Where(x => x.Value is Function or ConcreteFunction)
        .ToDictionary(x => x.Key, x => x.Value);
}
} | } |
@@ -16,18 +16,38 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static class saveable_object_util | |||||
/// <summary>
/// A saveable object that carries what is needed to checkpoint a `Trackable`.
/// </summary>
public class TrackableSaveable : MySaveableObject
{
    private Trackable _trackable;
    private string _prefix;
    private IEnumerable<string> _local_names;
    private bool _call_with_mapped_captures;

    // TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict.
    /// <summary>
    /// Stores the trackable together with its save specs and naming information.
    /// </summary>
    /// <param name="obj">The trackable being saved; handed to the base constructor as a `Tensor` via an `as` cast.</param>
    /// <param name="specs">Save specs, materialized to an array for the base constructor.</param>
    /// <param name="name">Name passed to the base constructor.</param>
    /// <param name="local_names">Local tensor names kept by this instance.</param>
    /// <param name="prefix">Prefix kept by this instance.</param>
    /// <param name="call_with_mapped_captures">Flag kept by this instance.</param>
    public TrackableSaveable(Trackable obj, IEnumerable<SaveSpec> specs, string name, IEnumerable<string> local_names,
        string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name)
    {
        (_trackable, _prefix) = (obj, prefix);
        (_local_names, _call_with_mapped_captures) = (local_names, call_with_mapped_captures);
    }

    // TODO: complete this class.
}
public static class saveable_object_util | |||||
{ | |||||
/// <summary> | /// <summary> | ||||
/// Returns the variables and names that will be used for a Saver. | /// Returns the variables and names that will be used for a Saver. | ||||
/// </summary> | /// </summary> | ||||
@@ -57,7 +77,7 @@ namespace Tensorflow | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Create `SaveableObject`s from an operation. | |||||
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | |||||
/// </summary> | /// </summary> | ||||
/// <param name="op"></param> | /// <param name="op"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
@@ -79,6 +99,74 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
/// <summary>
/// Create `SaveableObject`s from a `Trackable`.
/// </summary>
/// <remarks>
/// A resource variable yields a single `ResourceVariableSaveable`. Any other trackable is expanded
/// through `saveable_objects_from_trackable`, and each entry is converted recursively under its own name.
/// </remarks>
/// <param name="obj">The variable or trackable to convert.</param>
/// <param name="name">Name given to the saveable of a resource variable; not used for other trackables.</param>
/// <returns>A lazily evaluated sequence of saveable objects.</returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Trackable obj, string name)
{
    // The `obj` may be a `Variable` or any other `Trackable`.
    if (obj is BaseResourceVariable variable)
    {
        if (variable.InGraphMode)
        {
            yield return new ResourceVariableSaveable(variable.GraphElement, "", name);
        }
        else
        {
            Debug.Assert(variable is ResourceVariable);
            yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name);
        }
        yield break;
    }

    foreach (var pair in saveable_objects_from_trackable(obj))
    {
        var factory = pair.Value;
        if (factory.DataType == typeof(ResourceVariable))
        {
            var resource_variable = factory.GetValueA();
            // The static type `Trackable` selects this overload for the recursive call.
            foreach (var saveable in saveable_objects_for_op(resource_variable as Trackable, resource_variable.Name))
            {
                yield return saveable;
            }
        }
        else
        {
            var saveable_object = factory.GetValueB();
            foreach (var saveable in saveable_objects_for_op(saveable_object, saveable_object.name))
            {
                yield return saveable;
            }
        }
    }
}
/// <summary>
/// Wraps an existing saveable object in a one-element sequence.
/// </summary>
/// <param name="obj">The saveable object to yield.</param>
/// <param name="name">Not used by this overload.</param>
/// <returns>A lazily evaluated sequence whose only element is <paramref name="obj"/>.</returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(MySaveableObject obj, string name)
{
    // Already a saveable object: nothing to convert.
    yield return obj;
}
public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) | public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) | ||||
{ | { | ||||
op_list = op_list.OrderBy(x => x.Name).ToArray(); | op_list = op_list.OrderBy(x => x.Name).ToArray(); | ||||
@@ -127,16 +215,55 @@ namespace Tensorflow | |||||
return names_to_saveables; | return names_to_saveables; | ||||
} | } | ||||
/// <summary>
/// Returns the saveables of a trackable, keyed by attribute name.
/// </summary>
/// <remarks>
/// A trackable that provides its own `serialize_to_tensors` is wrapped in one `TrackableSaveable`
/// built from the tensors it returns; any other trackable falls back to
/// `gather_saveables_for_checkpoint`.
/// </remarks>
/// <param name="obj">The trackable to inspect.</param>
public static IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
{
    // skip the process of type `PythonState`
    if (!trackable_has_serialize_to_tensor(obj))
    {
        return obj.gather_saveables_for_checkpoint();
    }

    var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
    // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
    var tensor_dict = obj.serialize_to_tensors();
    var specs = new List<SaveSpec>();
    var local_names = new List<string>();
    string prefix = SaveableCompat.get_saveable_name(obj) ?? "";

    foreach (var entry in tensor_dict)
    {
        var tensor_name = entry.Key;
        var maybe_tensor = entry.Value;
        local_names.Add(tensor_name);
        string spec_name = name + TrackableUtils.escape_local_name(tensor_name);

        // A plain tensor is treated as a single slice with an empty slice spec.
        IDictionary<string, Tensor> slices = maybe_tensor.DataType == typeof(Tensor)
            ? new Dictionary<string, Tensor> { [""] = maybe_tensor.GetValueA() }
            : maybe_tensor.GetValueB();

        foreach (var slice in slices)
        {
            specs.Add(new SaveSpec(slice.Value, slice.Key, spec_name));
        }
    }

    return new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>
    {
        [name] = new TrackableSaveable(obj, specs, name, local_names, prefix)
    };
}
/// <summary>
/// Tells whether the runtime type of <paramref name="obj"/> provides its own `serialize_to_tensors`.
/// </summary>
/// <param name="obj">The trackable to inspect.</param>
/// <returns>
/// True when the method found by reflection is declared by a type other than `Trackable`.
/// </returns>
public static bool trackable_has_serialize_to_tensor(Trackable obj)
{
    var method = obj.GetType().GetMethod("serialize_to_tensors");
    var declaring_type = method.DeclaringType;
    return declaring_type != typeof(Trackable);
}
internal static string convert_to_string(string x) | internal static string convert_to_string(string x) | ||||
@@ -158,27 +285,28 @@ namespace Tensorflow | |||||
public Trackable Obj => _obj; | public Trackable Obj => _obj; | ||||
public IList<MySaveableObject> mySaveables=> _saveables; | public IList<MySaveableObject> mySaveables=> _saveables; | ||||
/// <summary>
/// Converts the wrapped saveables to a tensor dictionary.
/// </summary>
/// <returns>The result of `saveable_object_to_tensor_dict` for the saveables held by this instance.</returns>
public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
    => saveable_object_to_tensor_dict(_saveables);
/// <summary> | /// <summary> | ||||
/// Converts a list of SaveableObjects to a tensor dictionary. | /// Converts a list of SaveableObjects to a tensor dictionary. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="saveables"></param> | /// <param name="saveables"></param> | ||||
public static Dictionary<string, object> saveable_objects_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
{ | { | ||||
Dictionary<string, object> tensor_dict = new(); | |||||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||||
foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
{ | { | ||||
foreach(var spec in saveable.specs) | foreach(var spec in saveable.specs) | ||||
{ | { | ||||
// skip the check that if `spec` is callable. | |||||
var name = saveable_object_util.convert_to_string(spec.name); | var name = saveable_object_util.convert_to_string(spec.name); | ||||
var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); | var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); | ||||
if (!string.IsNullOrEmpty(slice_spec)) | if (!string.IsNullOrEmpty(slice_spec)) | ||||
{ | { | ||||
throw new NotImplementedException(); | |||||
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -16,7 +16,10 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.ModelSaving; | using Tensorflow.ModelSaving; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -39,8 +42,8 @@ namespace Tensorflow.Train | |||||
protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | ||||
protected IDictionary<string, ResourceVariable> _self_saveable_object_factories = | |||||
new Dictionary<string, ResourceVariable>(); | |||||
protected IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||||
new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>(); | |||||
private bool _manual_tracking = true; | private bool _manual_tracking = true; | ||||
private static Trackable _none = new Function(); | private static Trackable _none = new Function(); | ||||
@@ -94,9 +97,13 @@ namespace Tensorflow.Train | |||||
// assign again. It will add this variable to our dependencies, and if there | // assign again. It will add this variable to our dependencies, and if there | ||||
// is a non-trivial restoration queued, it will handle that. This also | // is a non-trivial restoration queued, it will handle that. This also | ||||
// handles slot variables. | // handles slot variables. | ||||
if (!args.Overwrite || new_variable is RefVariable) | |||||
return _track_checkpointable(new_variable, name: args.Name, | |||||
overwrite: args.Overwrite); | |||||
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) | |||||
{ | |||||
var temp = new_variable as Trackable; | |||||
var res = _track_trackable(temp, args.Name, args.Overwrite); | |||||
Debug.Assert(res is IVariableV1); | |||||
return res as IVariableV1; | |||||
} | |||||
else | else | ||||
return new_variable; | return new_variable; | ||||
} | } | ||||
@@ -122,13 +129,16 @@ namespace Tensorflow.Train | |||||
/// </summary> | /// </summary> | ||||
public void _maybe_initialize_trackable()
{
    // A non-null dependency list means the tracking state was already set up,
    // so the fields are only (re)created when that list does not exist yet.
    if (_unconditional_checkpoint_dependencies is null)
    {
        _self_update_uid = -1;
        _unconditional_checkpoint_dependencies = new List<TrackableReference>();
        _unconditional_dependency_names = new Dictionary<string, Trackable>();
    }
}
// TODO: cache | |||||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||||
{ | { | ||||
_maybe_initialize_trackable(); | _maybe_initialize_trackable(); | ||||
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | ||||
@@ -139,8 +149,8 @@ namespace Tensorflow.Train | |||||
_maybe_initialize_trackable(); | _maybe_initialize_trackable(); | ||||
if (!_manual_tracking) return trackable; | if (!_manual_tracking) return trackable; | ||||
var new_reference = new TrackableReference(name, trackable); | var new_reference = new TrackableReference(name, trackable); | ||||
var current_object = _lookupup_dependency(name); | |||||
var current_object = _lookup_dependency(name); | |||||
if(current_object is null) | if(current_object is null) | ||||
{ | { | ||||
_unconditional_checkpoint_dependencies.Add(new_reference); | _unconditional_checkpoint_dependencies.Add(new_reference); | ||||
@@ -170,7 +180,7 @@ namespace Tensorflow.Train | |||||
// TODO: complete the implementation. | // TODO: complete the implementation. | ||||
} | } | ||||
public virtual Trackable? _lookupup_dependency(string name) | |||||
public virtual Trackable? _lookup_dependency(string name) | |||||
{ | { | ||||
if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; | if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; | ||||
else return null; | else return null; | ||||
@@ -199,8 +209,8 @@ namespace Tensorflow.Train | |||||
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); | return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); | ||||
} | } | ||||
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null, | |||||
IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null) | |||||
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable> object_map, | |||||
IDictionary<Tensor, Tensor> tensor_map, SaveOptions? options = null) | |||||
{ | { | ||||
var (self_object_map, self_tensor_map) = map_resources(options); | var (self_object_map, self_tensor_map) = map_resources(options); | ||||
foreach (var pair in self_object_map) | foreach (var pair in self_object_map) | ||||
@@ -215,9 +225,17 @@ namespace Tensorflow.Train | |||||
return self_tensor_map.Keys.ToList(); | return self_tensor_map.Keys.ToList(); | ||||
} | } | ||||
public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint() | |||||
public virtual IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||||
{ | { | ||||
return _self_saveable_object_factories; | |||||
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||||
{ | |||||
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | |||||
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
return _self_saveable_object_factories; | |||||
} | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -229,7 +247,7 @@ namespace Tensorflow.Train | |||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
/// <exception cref="NotImplementedException"></exception> | /// <exception cref="NotImplementedException"></exception> | ||||
public virtual IDictionary<string, object> serialize_to_tensors() | |||||
public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
@@ -1,4 +1,5 @@ | |||||
using System.Collections.Generic; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
@@ -22,7 +23,7 @@ public static class TrackableUtils | |||||
private static string _ESCAPE_CHAR = "."; | private static string _ESCAPE_CHAR = "."; | ||||
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | ||||
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | ||||
private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||||
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||||
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | ||||
{ | { | ||||
return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); | return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); | ||||
@@ -145,4 +146,27 @@ public static class TrackableUtils | |||||
return $"root.{string.Join(".", paths.Select(x => x.Name))}"; | return $"root.{string.Join(".", paths.Select(x => x.Name))}"; | ||||
} | } | ||||
} | } | ||||
/// <summary>
/// Returns the part of a checkpoint key that follows the ".ATTRIBUTES/" marker
/// (plus the optional prefix).
/// </summary>
/// <param name="key">The checkpoint key to inspect.</param>
/// <param name="prefix">Text expected directly after the marker; null is treated as an empty string.</param>
/// <returns>
/// The remainder of <paramref name="key"/> after the marker and prefix, or
/// <paramref name="key"/> unchanged when the marker is not present.
/// </returns>
public static string extract_local_name(string key, string? prefix = null)
{
    var search_key = OBJECT_ATTRIBUTES_NAME + "/" + (prefix ?? "");

    // string.IndexOf signals "not found" with -1 instead of throwing, so the miss
    // must be tested explicitly. Relying on Substring to throw would instead slice
    // the key at search_key.Length - 1 and return an unrelated tail.
    // Ordinal comparison is used because checkpoint keys are identifiers, not text.
    var marker_start = key.IndexOf(search_key, StringComparison.Ordinal);
    if (marker_start < 0)
    {
        return key;
    }

    return key.Substring(marker_start + search_key.Length);
}
} | } |
@@ -9,6 +9,7 @@ using System.Runtime.InteropServices; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using static Tensorflow.ApiDef.Types; | using static Tensorflow.ApiDef.Types; | ||||
@@ -243,7 +244,7 @@ namespace Tensorflow.Training | |||||
_last_wrapped_list_snapshot = new List<Trackable>(_storage); | _last_wrapped_list_snapshot = new List<Trackable>(_storage); | ||||
} | } | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
check_external_modification(); | check_external_modification(); | ||||
if (_non_append_mutation_value) | if (_non_append_mutation_value) | ||||
@@ -4,6 +4,8 @@ using Tensorflow.Eager; | |||||
using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Collections.Generic; | |||||
using Tensorflow.ModelSaving; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -20,6 +22,7 @@ namespace Tensorflow | |||||
public string UniqueId => _unique_id; | public string UniqueId => _unique_id; | ||||
protected bool _in_graph_mode; | protected bool _in_graph_mode; | ||||
internal bool InGraphMode => _in_graph_mode; | |||||
protected bool _trainable; | protected bool _trainable; | ||||
public bool Trainable => _trainable; | public bool Trainable => _trainable; | ||||
@@ -17,7 +17,9 @@ | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -235,5 +237,12 @@ namespace Tensorflow | |||||
{ | { | ||||
return _graph_element.eval(session); | return _graph_element.eval(session); | ||||
} | } | ||||
/// <summary>
/// Builds a single-entry dictionary that maps the variable-value key to this instance.
/// </summary>
/// <returns>A new dictionary containing only this object under <c>VARIABLE_VALUE_KEY</c>.</returns>
public override IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
    return new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>
    {
        [Trackable.Constants.VARIABLE_VALUE_KEY] = this
    };
}
} | } | ||||
} | } |
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -351,7 +352,7 @@ namespace Tensorflow.Keras.Engine | |||||
return output_tensors; | return output_tensors; | ||||
} | } | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | ||||
.ToDictionary(x => x.Key, x => x.Value); | .ToDictionary(x => x.Key, x => x.Value); | ||||
@@ -1,4 +1,5 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
@@ -9,16 +10,16 @@ public abstract partial class Layer | |||||
{ | { | ||||
public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | ||||
public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||||
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||||
public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
IDictionary<string, Trackable> children; | IDictionary<string, Trackable> children; | ||||
if (save_type == SaveType.SAVEDMODEL) | if (save_type == SaveType.SAVEDMODEL) | ||||
{ | { | ||||
// TODO: deal with cache. | |||||
Debug.Assert(cache is not null); | |||||
children = TrackableSavedModelSaver.trackable_children(cache); | children = TrackableSavedModelSaver.trackable_children(cache); | ||||
} | } | ||||
else | else | ||||
@@ -88,9 +88,29 @@ namespace Tensorflow.Keras.Engine | |||||
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ||||
public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
public Tensor[] input => inboundNodes[0].input_tensors; | |||||
/// <summary>
/// Input tensors of the first inbound node, or null when the layer has no inbound nodes.
/// </summary>
public Tensor[] input
    => inboundNodes is null || inboundNodes.Count == 0
        ? null
        : inboundNodes[0].input_tensors;
public Dictionary<int, List<INode>> NodesByDepth { get; set; } | public Dictionary<int, List<INode>> NodesByDepth { get; set; } | ||||
public Shape OutputShape => inboundNodes[0].Outputs.shape; | |||||
/// <summary>
/// Shape of the first inbound node's outputs, or null when the layer has no inbound nodes.
/// </summary>
public Shape OutputShape
    => inboundNodes is null || inboundNodes.Count == 0
        ? null
        : inboundNodes[0].Outputs.shape;
protected List<ILayer> _self_tracked_trackables; | protected List<ILayer> _self_tracked_trackables; | ||||
public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||||
bool include_optimizer = true, | bool include_optimizer = true, | ||||
string save_format = "tf", | string save_format = "tf", | ||||
SaveOptions? options = null, | SaveOptions? options = null, | ||||
IDictionary<string, ConcreteFunction>? signatures = null, | |||||
ConcreteFunction? signatures = null, | |||||
bool save_traces = true) | bool save_traces = true) | ||||
{ | { | ||||
if (save_format != "pb") | if (save_format != "pb") | ||||
@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
} | } | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | { | ||||
if(save_type == SaveType.SAVEDMODEL) | if(save_type == SaveType.SAVEDMODEL) | ||||
{ | { | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||||
public partial class KerasSavedModelUtils | public partial class KerasSavedModelUtils | ||||
{ | { | ||||
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures, | |||||
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||||
SaveOptions? options, bool save_traces = true) | SaveOptions? options, bool save_traces = true) | ||||
{ | { | ||||
if (!overwrite && File.Exists(filepath)) | if (!overwrite && File.Exists(filepath)) | ||||
@@ -54,12 +54,7 @@ public partial class KerasSavedModelUtils | |||||
} | } | ||||
var metadata = generate_keras_metadata(saved_nodes, node_paths); | var metadata = generate_keras_metadata(saved_nodes, node_paths); | ||||
using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, | |||||
FileAccess.Write)) | |||||
{ | |||||
var writer = new StreamWriter(f); | |||||
writer.Write(metadata.ToString()); | |||||
} | |||||
File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); | |||||
if (!include_optimizer) | if (!include_optimizer) | ||||
{ | { | ||||
@@ -19,7 +19,7 @@ public partial class KerasSavedModelUtils | |||||
/// <param name="layer"></param> | /// <param name="layer"></param> | ||||
/// <param name="serialization_cache"></param> | /// <param name="serialization_cache"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, object> serialization_cache) | |||||
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | ||||
@@ -55,7 +55,7 @@ public partial class KerasSavedModelUtils | |||||
/// <param name="layer"></param> | /// <param name="layer"></param> | ||||
/// <param name="serialization_cache"></param> | /// <param name="serialization_cache"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, object> serialization_cache) | |||||
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
// TODO: deal with type `RevivedLayer` and `Sequential`. | // TODO: deal with type `RevivedLayer` and `Sequential`. | ||||
@@ -18,12 +18,12 @@ public abstract class SavedModelSaver | |||||
public abstract string TrackingMetadata { get; } | public abstract string TrackingMetadata { get; } | ||||
public abstract IDictionary<string, Trackable> objects_to_serialize( | public abstract IDictionary<string, Trackable> objects_to_serialize( | ||||
IDictionary<string, object> serialization_cache); | |||||
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache); | |||||
public abstract IDictionary<string, Trackable> functions_to_serialize( | public abstract IDictionary<string, Trackable> functions_to_serialize( | ||||
IDictionary<string, object> serialization_cache); | |||||
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache); | |||||
public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | |||||
public IDictionary<string, Trackable> trackable_children(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
if (!KerasSavedModelUtils.ShouldHaveTraces) | if (!KerasSavedModelUtils.ShouldHaveTraces) | ||||
{ | { | ||||
@@ -31,7 +31,6 @@ public abstract class SavedModelSaver | |||||
} | } | ||||
var children = objects_to_serialize(serialization_cache); | var children = objects_to_serialize(serialization_cache); | ||||
return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | ||||
.ToDictionary(x => x.Key, x => x.Value); | .ToDictionary(x => x.Key, x => x.Value); | ||||
} | } | ||||
@@ -19,12 +19,12 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
get => Constants.LAYER_IDENTIFIER; | get => Constants.LAYER_IDENTIFIER; | ||||
} | } | ||||
public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||||
public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
return get_serialized_attributes(serialization_cache).ObjectsToSerialize; | return get_serialized_attributes(serialization_cache).ObjectsToSerialize; | ||||
} | } | ||||
public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||||
public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
return get_serialized_attributes(serialization_cache).FunctionsToSerialize; | return get_serialized_attributes(serialization_cache).FunctionsToSerialize; | ||||
} | } | ||||
@@ -33,11 +33,21 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
/// Generates or retrieves serialized attributes from cache. | /// Generates or retrieves serialized attributes from cache. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="serialization_cache"></param> | /// <param name="serialization_cache"></param> | ||||
protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||||
protected ISerializedAttributes get_serialized_attributes(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
// TODO: deal with cache. | // TODO: deal with cache. | ||||
IDictionary<Trackable, ISerializedAttributes> keras_cache; | |||||
if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY)) | |||||
{ | |||||
keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY]; | |||||
} | |||||
else | |||||
{ | |||||
serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary<Trackable, ISerializedAttributes>(); | |||||
} | |||||
if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj]; | |||||
var serialized_attr = SerializedAttributes.Create(_obj); | |||||
var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj); | |||||
// TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. | // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. | ||||
if (KerasSavedModelUtils.should_skip_serialization(_obj)) | if (KerasSavedModelUtils.should_skip_serialization(_obj)) | ||||
@@ -56,7 +66,7 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
/// Returns dictionary of serialized attributes. | /// Returns dictionary of serialized attributes. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="serialization_cache"></param> | /// <param name="serialization_cache"></param> | ||||
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache) | |||||
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
{ | { | ||||
var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); | var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); | ||||
var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); | var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); | ||||
@@ -75,7 +85,7 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
metadata["trainable"] = _obj.Trainable; | metadata["trainable"] = _obj.Trainable; | ||||
// metadata["expects_training_arg"] = _obj._expects_training_arg; | // metadata["expects_training_arg"] = _obj._expects_training_arg; | ||||
// metadata["dtype"] = policy.serialize(_obj._dtype_policy) | // metadata["dtype"] = policy.serialize(_obj._dtype_policy) | ||||
metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); | |||||
metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); | |||||
// metadata["stateful"] = _obj.stateful; | // metadata["stateful"] = _obj.stateful; | ||||
// metadata["must_restore_from_config"] = _obj.must_restore_from_config; | // metadata["must_restore_from_config"] = _obj.must_restore_from_config; | ||||
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | ||||
@@ -92,8 +102,10 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
} | } | ||||
} | } | ||||
public static LayerConfig get_serialized(Layer obj) | |||||
public static IDictionary<string, object> get_serialized(Layer obj) | |||||
{ | { | ||||
return generic_utils.serialize_keras_object(obj); | |||||
// TODO: complete the implementation (need to revise `get_config`). | |||||
return new Dictionary<string, object>(); | |||||
//return generic_utils.serialize_keras_object(obj); | |||||
} | } | ||||
} | } |
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
/// <summary> | /// <summary> | ||||
/// Class that tracks and validates all serialization attributes. | /// Class that tracks and validates all serialization attributes. | ||||
/// </summary> | /// </summary> | ||||
public abstract class SerializedAttributes | |||||
public abstract class SerializedAttributes: ISerializedAttributes | |||||
{ | { | ||||
protected IDictionary<string, Trackable?> _object_dict; | protected IDictionary<string, Trackable?> _object_dict; | ||||
protected IDictionary<string, Trackable?> _function_dict; | protected IDictionary<string, Trackable?> _function_dict; | ||||
@@ -50,11 +50,11 @@ public class SaveTest | |||||
{ | { | ||||
TrainDir = "mnist", | TrainDir = "mnist", | ||||
OneHot = false, | OneHot = false, | ||||
ValidationSize = 50000, | |||||
ValidationSize = 0, | |||||
}).Result; | }).Result; | ||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | ||||
model.save("", save_format:"pb"); | |||||
model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); | |||||
} | } | ||||
} | } |