Browse Source

Add lacked implementations (mainly MultiDeviceSaver).

pull/976/head
AsakusaRinne 2 years ago
parent
commit
83906b8f79
30 changed files with 1037 additions and 161 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
  2. +5
    -4
      src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
  3. +12
    -11
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  4. +11
    -16
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  5. +3
    -2
      src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
  6. +6
    -3
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  7. +510
    -5
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  8. +35
    -0
      src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs
  9. +2
    -1
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  10. +91
    -18
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs
  11. +3
    -3
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  12. +8
    -8
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  13. +74
    -25
      src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs
  14. +142
    -14
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  15. +33
    -15
      src/TensorFlowNET.Core/Training/Trackable.cs
  16. +26
    -2
      src/TensorFlowNET.Core/Training/TrackableUtils.cs
  17. +2
    -1
      src/TensorFlowNET.Core/Training/data_structures.cs
  18. +3
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  19. +9
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  20. +2
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  21. +4
    -3
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  22. +22
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  23. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Save.cs
  24. +2
    -1
      src/TensorFlowNET.Keras/Engine/Model.cs
  25. +2
    -7
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  26. +2
    -2
      src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
  27. +3
    -4
      src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
  28. +20
    -8
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  29. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  30. +2
    -2
      test/TensorFlowNET.Keras.UnitTest/SaveTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs View File

@@ -1,5 +1,5 @@
namespace Tensorflow.Checkpoint;

public record class CheckpointOptions(
string experimental_io_device = null,
string? experimental_io_device = null,
bool experimental_enable_async_checkpoint = false);

+ 5
- 4
src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using Serilog.Debugging;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;
@@ -21,9 +22,9 @@ public class ObjectGraphView: TrackableView, ICloneable
return new ObjectGraphView(Root, _attached_dependencies);
}

public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
List<TrackableReference> res = base.children(obj, save_type)
List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
// Check the reference, not value.
if (obj == Root && _attached_dependencies is not null)
@@ -34,9 +35,9 @@ public class ObjectGraphView: TrackableView, ICloneable
return res;
}
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer);
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
}
public IEnumerable<TrackableReference>? AttachedDependencies


+ 12
- 11
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint
);
public static class SaveUtil
{
public static (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -117,16 +117,16 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param>
/// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, object>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, IDictionary<string, object>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
foreach(var td in tensor_trackables)
{
// TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
var trackable = td.object_to_save;
IDictionary<string, object> tensor_dict;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -147,12 +147,12 @@ namespace Tensorflow.Checkpoint
return serialized_tensors;
}

private static IDictionary<string, object> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
var trackable = trackable_data.object_to_save;

// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, object> ret_tensor_dict;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
if (call_with_mapped_captures)
{
throw new NotImplementedException();
@@ -162,8 +162,8 @@ namespace Tensorflow.Checkpoint
ret_tensor_dict = trackable.serialize_to_tensors();
}

// TODO: revise the types and complete it
Dictionary<string, object> tensor_dict = new();
// TODO: deal with the type `SaveSpce` (currently it will never be it).
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict)
{
var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -172,9 +172,10 @@ namespace Tensorflow.Checkpoint

tensor_dict[checkpoint_key] = maybe_tensor;

if(maybe_tensor is SaveSpec)
if(maybe_tensor.GetValueA() is SaveSpec)
{
((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
}

if(object_graph_proto is not null)
@@ -198,7 +199,7 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param>
/// <returns></returns>
private static (Trackable, IDictionary<string, object>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, string> object_names = new();


+ 11
- 16
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -174,25 +174,20 @@ public static class SaveUtilV1
{
var name = factory_data.name;
var key = factory_data.checkpoint_key;
var saveable_factory = factory_data.factory;
var maybe_saveable = factory_data.factory;
// TODO: oneflow python has a process with callable `saveable_factory`.
var maybe_saveable = saveable_factory;
IEnumerable<MySaveableObject> savesbles;
if (maybe_saveable is MySaveableObject)
{
savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable };
}
else if (maybe_saveable is Tensor)
List<MySaveableObject> saveables = new();
if (maybe_saveable.DataType == typeof(MySaveableObject))
{
savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key);
saveables.Add(maybe_saveable.GetValueB());
}
else
{
throw new TypeError("Unexpected type.");
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key));
}

foreach (var saveable in savesbles)
foreach (var saveable in saveables)
{
if (!saveable.name.Contains(key))
{
@@ -204,11 +199,11 @@ public static class SaveUtilV1
// skip the process of PythonState
named_saveable_objects.AddRange(savesbles);
named_saveable_objects.AddRange(saveables);
if(!fill_object_proto) continue;
// skip the process of TrackableSaveable
// skip the process of `TrackableSaveable` because of lack of APIs.

object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
{ Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
@@ -221,7 +216,7 @@ public static class SaveUtilV1

public record class CheckpointFactoryData
(
object factory,
Maybe<ResourceVariable, MySaveableObject> factory,
string name,
string checkpoint_key
);

+ 3
- 2
src/TensorFlowNET.Core/Checkpoint/TrackableView.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.Train;
using System.Collections.Generic;
using System.IO;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow.Checkpoint;

@@ -18,13 +19,13 @@ public class TrackableView
_root_ref = obj;
}
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
obj._maybe_initialize_trackable();
Dictionary<string, Trackable> children = new();
// Note: in python the return type of `Trackable._trackable_children` is not fixed.
// Therefore it uses `convert_to_trackable` to have an extra process.
foreach (var pair in obj._trackable_children(save_type))
foreach (var pair in obj._trackable_children(save_type, cache))
{
children[pair.Key] = pair.Value;
}


+ 6
- 3
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -33,7 +33,7 @@ public class TrackableSaver
}

private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);
@@ -125,7 +125,7 @@ public class TrackableSaver
}

Dictionary<Tensor, string> feed_dict = new();
bool use_session = (!new Context().executing_eagerly() && !ops.inside_function());
bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
if (checkpoint_number is not null)
{
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
@@ -133,6 +133,7 @@ public class TrackableSaver

Tensor file_prefix_tensor;
Tensor object_graph_tensor;
string file_prefix_to_save;
if (use_session)
{
if (_object_graph_feed_tensor is null)
@@ -145,16 +146,18 @@ public class TrackableSaver
object_graph_tensor = _object_graph_feed_tensor;
file_prefix_tensor = _file_prefix_feed_tensor;
feed_dict[file_prefix_tensor] = file_prefix;
file_prefix_to_save = "";
}
else
{
// In python there is `with ops.device("/cpu:0")`.
file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
object_graph_tensor = null;
file_prefix_to_save = file_prefix;
}

var (save_path, new_feed_additions) =
save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);
save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options);

if (new_feed_additions is not null)
{


+ 510
- 5
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -6,9 +6,254 @@ using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.OptimizerOptions.Types;
using static Tensorflow.Binding;
using System.Text.RegularExpressions;
using System.Linq;
using Tensorflow.Operations;
using Tensorflow.Training;
using Tensorflow.Graphs;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions.
/// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users.
/// </summary>
public interface IFunctionHolder
{
int ArgCount { get; }
object DynamicInvoke(params object[] args);
}
internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder
{
public int ArgCount => 0;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder
{
public int ArgCount => 1;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder
{
public int ArgCount => 2;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder
{
public int ArgCount => 3;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
public class Maybe<TA, TB>
{
private TA? _valueA = default(TA);
private TB? _valueB = default(TB);
private Type _type;
private bool _assigned = false;
public Maybe(TA value)
{
_valueA = value;
_type= typeof(TA);
_assigned = true;
}
public Maybe(TB value)
{
_valueB = value;
_type = typeof(TB);
_assigned = true;
}

public Type DataType => _type;

public TA GetValueA()
{
if(!_assigned || DataType != typeof(TA))
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
return _valueA;
}
public TB GetValueB()
{
if (!_assigned || DataType != typeof(TB))
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
return _valueB;
}
public object GetValue()
{
if (!_assigned)
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
if(DataType == typeof(TA) && _valueA is not null)
{
return _valueA;
}
else if(DataType == typeof(TB) && _valueB is not null)
{
return _valueB;
}
else if(DataType == typeof(TA))
{
return _valueA;
}
else
{
return _valueB;
}
}

public static implicit operator Maybe<TA, TB>(TA a)
{
return new Maybe<TA, TB>(a);
}
public static implicit operator Maybe<TA, TB>(TB b)
{
return new Maybe<TA, TB>(b);
}
}
internal class SingleDeviceSaver
{
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict;
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
}
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}
List<string> tensor_names = new();
List<Tensor> tensors = new();
List<string> slice_specs = new();
foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice in tensor_slices)
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
var tensor_value = spec.tensor;
if (tensor_value is not null)
{
tensor_names.Add(spec.name);
tensors.Add(tensor_value);
slice_specs.Add(spec.slice_spec);
}
}
else
{
var tensor = maybe_tensor.GetValueA();
tensor_names.Add(checkpoint_key);
tensors.Add(tensor);
slice_specs.Add(slice_spec);
}
}
}
// TODO: specify the device.
return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray());
}

public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options);

public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}
List<string> tensor_names = new();
List<TF_DataType> tensor_dtypes = new();
List<string> slice_specs = new();

foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice in tensor_slices)
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
tensor_dtypes.Add(spec.dtype);
slice_specs.Add(spec.slice_spec);
tensor_names.Add(spec.name);
}
else
{
var tensor = maybe_tensor.GetValueA();
tensor_dtypes.Add(tensor.dtype);
slice_specs.Add(slice_spec);
tensor_names.Add(checkpoint_key);
}
}
}

string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

// tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky.
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());

Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0;
foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice_spec in tensor_slices.Keys)
{
var restored_tensor = restored_tensors[idx++];
if (!restored_tensor_dict.ContainsKey(checkpoint_key))
{
restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>();
}
restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor;
}
}
return restored_tensor_dict;
}

public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix));
}
/// <summary>
/// Saves checkpoints directly from multiple devices.
/// Note that this is a low-level utility which stores Tensors in the keys
@@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint
/// </summary>
public class MultiDeviceSaver
{
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, object>> serialized_tensors,
private Dictionary<string, SingleDeviceSaver> _single_device_savers;
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers;
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn;
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys;
/// <summary>
///
/// </summary>
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers"></param>
/// <param name="call_with_mapped_capture"></param>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{
_keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>();
_restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>();
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
foreach(var pair in serialized_tensors)
{
var obj = pair.Key;
var tensor_dict = pair.Value;
IFunctionHolder restore_fn;
if(obj is null)
{
restore_fn = new FunctionHolder<object?>(() => null);
}
else
{
restore_fn = null;
// TODO: implement obj._restore_from_tensors
}

foreach(var item in tensor_dict)
{
var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.DataType != typeof(IDictionary<string, Tensor>))
{
spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = item.Value.GetValueA();
}
else
{
spec_to_tensor = item.Value.GetValueB();
}

foreach(var spec in spec_to_tensor)
{
var slice_spec = spec.Key;
var tensor = spec.Value;
if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec)))
{
throw new ValueError("Recieved multiple tensors with the same checkpoint key and " +
$"slice spec. This is invalid because one will overwrite the " +
$"other in the checkpoint. This indicates a bug in the Checkpoint key-generation.");
}
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

// skip the process of device name because lack of API.
var host_device = tensor.Device;
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
if (!internal_dict.ContainsKey(checkpoint_key))
{
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>();
}
internal_dict[checkpoint_key][slice_spec] = tensor;
}
}
}

_single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value));

_registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>();
if(registered_savers is not null && registered_savers.Count > 0)
{
// TODO: complete the implementation.
throw new NotImplementedException();
}
}

public Operation? save(string file_prefix, CheckpointOptions? options= null)
public Operation save(string file_prefix, CheckpointOptions? options= null)
{
throw new NotImplementedException();
if(options is null)
{
options = new CheckpointOptions();
}

tf.device("CPU"); // may be risky.
// TODO: optimize the implementation with new APIs adding to `string_ops`.
string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part";
var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix);
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));

Operation save_fn()
{
List<Tensor> saved_prefixes= new();
foreach(var saver in _registered_savers)
{
// TODO: implementi it later.
throw new NotImplementedException();
}

int num_shards = _single_device_savers.Count;
List<Operation> sharded_saves = new();
var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards");
string? last_device = null;
int shard = 0;
foreach(var pair in _single_device_savers.OrderBy(x => x.Key))
{
var device = pair.Key;
var saver = pair.Value;
last_device = device;
// skip the extra process of device name because of lack of API.
tf.device(device);
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
saved_prefixes.Add(shard_prefix);
sharded_saves.Add(saver.save(shard_prefix, options));
}
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
{
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
tf.device(merge_device);
return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true);
}
}

if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1)
{
// TODO: implement it. Currently `autograph` does not support the function with non parameter.
throw new NotImplementedException();
}
else
{
return save_fn();
}
}

public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options);

public IDictionary<string, Operation> restore(string file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}

IDictionary<string, Operation> restore_func()
{
Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<string, Operation> restore_ops = new();

foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key))
{
var device = single_saver.Key;
var saver = single_saver.Value;
tf.device(device);
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach(var pair in restored_tensor_dict)
{
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach(var item in slice_and_tensor)
{
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
{
if (!internal_dict.ContainsKey(checkpoint_key))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
}
else
{
internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor;
}
}
else
{
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
}
restore_fn_input_count[restore_fn]--;

if (restore_fn_input_count[restore_fn] == 0)
{
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach(var input in restore_fn_inputs[restore_fn])
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if(ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
}
}
}
}
}

foreach(var item in _registered_savers)
{
throw new NotImplementedException();
}
return restore_ops;
}

// TODO: complete the implementation. Currently skip it because of lack of API.
bool has_custom_device_saver = false;

if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver))
{
// TODO: implement it. Currently `autograph` does not support the function with non parameter.
throw new NotImplementedException();
}
else
{
return restore_func();
}
}

/// <summary>
/// Serializes to a SaverDef referencing the current graph.
/// </summary>
public SaverDef to_proto()
{
var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename");
var save_tensor = _traced_save(filename_tensor);
var restore_op = _traced_restore(filename_tensor).op;
return new SaverDef()
{
FilenameTensorName = filename_tensor.name,
SaveTensorName = save_tensor.name,
RestoreOpName = restore_op.name,
Version = SaverDef.Types.CheckpointFormatVersion.V2
};
}

[AutoGraph]
private Tensor _traced_save(Tensor file_prefix)
{
var save_op = save(file_prefix.StringData()[0]);
tf.device("cpu:0");
using (ops.control_dependencies(new object[]{ save_op }))
{
return array_ops.identity(file_prefix);
}
}

[AutoGraph]
private Tensor _traced_restore(Tensor file_prefix)
{
var restore_op = restore(file_prefix.StringData()[0]);
tf.device("cpu:0");
using (ops.control_dependencies(new object[] { restore_op }))
{
return array_ops.identity(file_prefix);
}
}

private static Tensor registered_saver_filename(string filename, string saver_name)
{
return tf.constant($"{filename}-{saver_name}");
}
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards)
{
throw new NotImplementedException();
return filename_tensor;
}
}
}

+ 35
- 0
src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel
{
public interface ISerializedAttributes
{
IDictionary<string, Trackable> Functions { get; }

IDictionary<string, Trackable> CheckpointableObjects { get; }

/// <summary>
/// Returns functions to attach to the root object during serialization.
/// </summary>
IDictionary<string, Trackable> FunctionsToSerialize { get; }

/// <summary>
/// Returns objects to attach to the root object during serialization.
/// </summary>
IDictionary<string, Trackable> ObjectsToSerialize{get; }

/// <summary>
/// Saves function dictionary, and validates dictionary values.
/// </summary>
/// <param name="function_dict"></param>
IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict);

/// <summary>
/// Saves objects to a dictionary, and validates the values.
/// </summary>
/// <param name="object_dict"></param>
IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict);
}
}

+ 2
- 1
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -1,6 +1,7 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

@@ -24,7 +25,7 @@ namespace Tensorflow.Train
}
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if(save_type != SaveType.SAVEDMODEL)
{


+ 91
- 18
src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs View File

@@ -4,57 +4,130 @@ using Tensorflow.Train;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow;

public class AugmentedGraphView: ObjectGraphView
{
// private object _children_cache;
// private object _serialization_cache;
private Dictionary<Trackable, IDictionary<string, Trackable>> _children_cache;
private Dictionary<string, IDictionary<Trackable, ISerializedAttributes>> _serialization_cache;
private List<string> _untraces_functions;
private Dictionary<ConcreteFunction, ConcreteFunction> _wrapped_functions;
public AugmentedGraphView(Trackable root): base(root)
{
_untraces_functions = new();
_children_cache= new Dictionary<Trackable, IDictionary<string, Trackable>>();
_serialization_cache = new Dictionary<string, IDictionary<Trackable, ISerializedAttributes>>();
_untraces_functions = new List<string>();
_wrapped_functions = new Dictionary<ConcreteFunction, ConcreteFunction>();
}

public void set_signature(object signature_map, object wrapped_functions)
public void set_signature(SignatureMap signature_map, IDictionary<ConcreteFunction, ConcreteFunction> wrapped_functions)
{
// TODO: cache
list_children(Root);
var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME;
if (!_children_cache.ContainsKey(Root))
{
_children_cache[Root] = new Dictionary<string, Trackable>();
}
_children_cache[Root][name] = signature_map;
_wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value);
}
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL)
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
Dictionary<string, Trackable> children = new();
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL))
if(serialization_cache is not null)
{
throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`.");
}

if (!_children_cache.ContainsKey(obj))
{
Dictionary<string, Trackable> children = new Dictionary<string, Trackable>();
_children_cache[obj] = children;
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache))
{
var name = pair.Name;
var child = pair.Refer;
if(child is ConcreteFunction)
{
child = maybe_uncache_variable_captures((ConcreteFunction)child);
}
children[name] = child;
}

if (obj is Function && children.Count == 0)
{
_untraces_functions.Add(((Function)obj).Name);
}
}

List<TrackableReference> res = new();
foreach(var pair in _children_cache[obj])
{
var name = pair.Name;
var child = pair.Refer;
children[name] = child;
res.Add(new TrackableReference(pair.Key, pair.Value));
}

if (obj is Function && children.Count == 0)
return res;
}

private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function)
{
if (_wrapped_functions.ContainsKey(concrete_function))
{
_untraces_functions.Add(((Function)obj).Name);
return _wrapped_functions[concrete_function];
}
// skip the process here because of lack of feature.
// In the future, we may add an attribute which could specify if the variable is supposed to be cached.
//foreach(var capture in concrete_function.CapturedInputs)
//{

return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
//}
return concrete_function;
}

public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{
// TODO: implement it if needed.
Trackable get_merged_trackable(Trackable x)
{
// TODO: complete it with new definitions `Asset` and `TrackableConstant`.
return x;
}
var trackable_objects = base.breadth_first_traversal();

foreach(var obj in _children_cache.Keys)
{
// skip the deletion of cache (maybe do it later).
foreach(var pair in _children_cache[obj])
{
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
}
}

return base.breadth_first_traversal();
}

public List<(string, Trackable)> list_dependencies(Trackable obj)
{
// TODO: deal with cache.
return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList();
IDictionary<string, Trackable> children;
if (!_children_cache.ContainsKey(obj))
{
children= new Dictionary<string, Trackable>();
}
else
{
children= _children_cache[obj];
}
List<(string, Trackable)> res = new();
foreach(var pair in obj.deserialization_dependencies(children))
{
res.Add((pair.Key, pair.Value));
}
return res;
}

public Trackable get_child(Trackable obj, string name)
{
throw new NotImplementedException();
return _children_cache[obj][name];
}
}

+ 3
- 3
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -141,16 +141,16 @@ public class SaveableView
foreach (var node in _nodes)
{
var node_id = _node_ids[node];
List<int> deps = new();
List<int> deps = new List<int>();
dependency_map.Add(node_id, deps);
// TODO: deal with captured tensor.

string node_path;
foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node))
{
if (!_node_ids.ContainsKey(dep))
{
node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
throw new ValueError(
$"Found an untracked dependency. Object {node_path} depends on {dep}, " +
$"but this dependency isn't listed as a child. Please track this child by " +


+ 8
- 8
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -24,7 +24,7 @@ public static partial class SavedModelUtils
}.Select(x => (int)x);
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj,
string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
{
if (options is null)
{
@@ -41,9 +41,9 @@ public static partial class SavedModelUtils

if (!experimental_skip_checkpoint)
{
Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir);
SavedModelUtils.get_or_create_variables_dir(export_dir);
CheckpointOptions ckpt_options = new(options.experimental_io_device);
object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options);
object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options);
}
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir);

@@ -67,7 +67,7 @@ public static partial class SavedModelUtils
}

var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB));
File.WriteAllText(path, saved_model.ToString());
File.WriteAllBytes(path, saved_model.ToByteArray());

if (options.save_debug_info)
{
@@ -81,7 +81,7 @@ public static partial class SavedModelUtils

private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>,
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
{
if (ops.inside_function())
{
@@ -95,9 +95,9 @@ public static partial class SavedModelUtils
}

AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
if (signatures is not null)
if (signatures is null)
{
throw new NotImplementedException();
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view);
}
// TODO: process of aignatures and wrapped_functions
@@ -125,7 +125,7 @@ public static partial class SavedModelUtils
}

private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view,
IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist,
ConcreteFunction signatures, IEnumerable<string> namespace_whitelist,
bool save_custom_gradients)
{
var resource_initializers = saveable_view.get_concrete_resource_initializers();


+ 74
- 25
src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs View File

@@ -1,15 +1,84 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow;

public static class SignatureSerializationUtils
{
internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature";
internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures";
internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5;
public static SignatureMap create_signature_map(IDictionary<string, Trackable> signatures)
{
var signature_map = new SignatureMap();
foreach (var pair in signatures)
{
var name = pair.Key;
var func = pair.Value;
Debug.Assert(func is ConcreteFunction);
// TODO: assert the `func.structured_outputs` and arg_keywords.
signature_map._add_signature(name, (ConcreteFunction)func);
}

return signature_map;
}

public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view)
{
var children = graph_view.list_children(graph_view.Root);
List<Trackable> possible_signatures = new();
foreach (var item in children)
{
var name = item.Name;
var child = item.Refer;
if(child is not (Function or ConcreteFunction))
{
continue;
}
if(name == DEFAULT_SIGNATURE_ATTR)
{
Debug.Assert(child is ConcreteFunction);
return (ConcreteFunction)child;
}
ConcreteFunction concrete = get_signature(child);
if(concrete is not null && valid_signature(concrete))
{
possible_signatures.Add(concrete);
}
}

if(possible_signatures.Count == 1)
{
var signature = get_signature(possible_signatures[0]);
if(signature is not null && valid_signature(signature))
{
return signature;
}
}
return null;
}

private static ConcreteFunction get_signature(Trackable function)
{
// TODO: implement it.
return null;
}

private static bool valid_signature(ConcreteFunction concreate_function)
{
// TODO: implement it.
return false;
}
}

public class SignatureMap: Trackable
{
private Dictionary<string, Function> _signatures;
private Dictionary<string, ConcreteFunction> _concrete_signatures;
private Dictionary<string, Trackable> _signatures;

public SignatureMap()
{
@@ -18,7 +87,7 @@ public class SignatureMap: Trackable

public void _add_signature(string name, ConcreteFunction concrete_function)
{
_concrete_signatures[name] = concrete_function;
_signatures[name] = concrete_function;
}
public void _add_signature(string name, Function concrete_function)
@@ -26,33 +95,13 @@ public class SignatureMap: Trackable
_signatures[name] = concrete_function;
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if (save_type != SaveType.SAVEDMODEL)
{
return new Dictionary<string, Trackable>();
}

Dictionary<string, Trackable> res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value);
foreach (var pair in _concrete_signatures)
{
res[pair.Key] = pair.Value;
}

return res;
}

public static SignatureMap create_signature_map(IDictionary<string, ConcreteFunction> signatures)
{
var signature_map = new SignatureMap();
foreach (var pair in signatures)
{
var name = pair.Key;
var func = pair.Value;
// TODO: assert the arg_keywords
signature_map._add_signature(name, func);
}

return signature_map;
return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
}
}

+ 142
- 14
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -16,18 +16,38 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Train;
using Tensorflow.Training;
using static Tensorflow.Binding;

namespace Tensorflow
{
public static class saveable_object_util
/// <summary>
/// A SaveableObject that defines `Trackable` checkpointing steps.
/// </summary>
public class TrackableSaveable : MySaveableObject
{
public class TrackableSaveable: MySaveableObject
private string _prefix;
private IEnumerable<string> _local_names;
private Trackable _trackable;
private bool _call_with_mapped_captures;
// TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict.
public TrackableSaveable(Trackable obj, IEnumerable<SaveSpec> specs, string name, IEnumerable<string> local_names,
string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name)
{
_prefix = prefix;
_trackable = obj;
_local_names = local_names;
_call_with_mapped_captures = call_with_mapped_captures;
}

// TODO: complete this class.
}
public static class saveable_object_util
{
/// <summary>
/// Returns the variables and names that will be used for a Saver.
/// </summary>
@@ -57,7 +77,7 @@ namespace Tensorflow
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
@@ -79,6 +99,74 @@ namespace Tensorflow
}
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Trackable obj, string name)
{
// The `op` maybe `Variable` or `Trackable`.
if (obj is BaseResourceVariable)
{
var variable = obj as BaseResourceVariable;
if (variable.InGraphMode)
{
yield return new ResourceVariableSaveable(variable.GraphElement, "", name);
}
else
{
Debug.Assert(variable is ResourceVariable);
yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name);
}
}
else
{
foreach(var pair in saveable_objects_from_trackable(obj))
{
var attr = pair.Key;
var factory = pair.Value;
string full_name;
if(attr == Trackable.Constants.VARIABLE_VALUE_KEY)
{
full_name = name;
}
else
{
full_name = name + "_" + attr;
}
if(factory.DataType == typeof(ResourceVariable))
{
var variable = factory.GetValueA();
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
{
yield return op;
}
}
else
{
var variable = factory.GetValueB();
foreach (var op in saveable_objects_for_op(variable, variable.name))
{
yield return op;
}
}
}
}
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(MySaveableObject obj, string name)
{
yield return obj;
}

public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true)
{
op_list = op_list.OrderBy(x => x.Name).ToArray();
@@ -127,16 +215,55 @@ namespace Tensorflow
return names_to_saveables;
}

public static IDictionary<string, ResourceVariable> saveable_objects_from_trackable(Trackable obj)
public static IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
{
// TODO: complete the implementation.
return obj.gather_saveables_for_checkpoint();
// skip the process of type `PythonState`

if (trackable_has_serialize_to_tensor(obj))
{
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
var tensor_dict = obj.serialize_to_tensors();

List<SaveSpec> specs = new();
List<string> local_names = new();
string prefix = SaveableCompat.get_saveable_name(obj) ?? "";
foreach(var pair in tensor_dict)
{
var tensor_name = pair.Key;
var maybe_tensor = pair.Value;
local_names.Add(tensor_name);
string spec_name = name + TrackableUtils.escape_local_name(tensor_name);

IDictionary<string, Tensor> internal_dict;
if(maybe_tensor.DataType == typeof(Tensor))
{
internal_dict= new Dictionary<string, Tensor>();
internal_dict[""] = maybe_tensor.GetValueA();
}
else
{
internal_dict = maybe_tensor.GetValueB();
}

foreach(var item in internal_dict)
{
specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
}
}
Dictionary<string, Maybe<ResourceVariable, MySaveableObject>> res = new();
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
return res;
}
else
{
return obj.gather_saveables_for_checkpoint();
}
}

public static bool trackable_has_serialize_to_tensor(Trackable obj)
{
// TODO: implement it.
return false;
return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable);
}

internal static string convert_to_string(string x)
@@ -158,27 +285,28 @@ namespace Tensorflow
public Trackable Obj => _obj;
public IList<MySaveableObject> mySaveables=> _saveables;

public override IDictionary<string, object> serialize_to_tensors()
public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{
return saveable_objects_to_tensor_dict(_saveables);
return saveable_object_to_tensor_dict(_saveables);
}

/// <summary>
/// Converts a list of SaveableObjects to a tensor dictionary.
/// </summary>
/// <param name="saveables"></param>
public static Dictionary<string, object> saveable_objects_to_tensor_dict(IList<MySaveableObject> saveables)
public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
{
Dictionary<string, object> tensor_dict = new();
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach (var saveable in saveables)
{
foreach(var spec in saveable.specs)
{
// skip the check that if `spec` is callable.
var name = saveable_object_util.convert_to_string(spec.name);
var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec))
{
throw new NotImplementedException();
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor;
}
else
{


+ 33
- 15
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -16,7 +16,10 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving;
using Tensorflow.Training;
using static Tensorflow.Binding;
@@ -39,8 +42,8 @@ namespace Tensorflow.Train

protected IList<TrackableReference> _unconditional_checkpoint_dependencies;

protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
new Dictionary<string, ResourceVariable>();
protected IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>();
private bool _manual_tracking = true;

private static Trackable _none = new Function();
@@ -94,9 +97,13 @@ namespace Tensorflow.Train
// assign again. It will add this variable to our dependencies, and if there
// is a non-trivial restoration queued, it will handle that. This also
// handles slot variables.
if (!args.Overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: args.Name,
overwrite: args.Overwrite);
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{
var temp = new_variable as Trackable;
var res = _track_trackable(temp, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1);
return res as IVariableV1;
}
else
return new_variable;
}
@@ -122,13 +129,16 @@ namespace Tensorflow.Train
/// </summary>
public void _maybe_initialize_trackable()
{
if(_unconditional_checkpoint_dependencies is not null)
{
return;
}
_self_update_uid = -1;
_unconditional_checkpoint_dependencies = new List<TrackableReference>();
_unconditional_dependency_names = new Dictionary<string, Trackable>();
}

// TODO: cache
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache)
{
_maybe_initialize_trackable();
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
@@ -139,8 +149,8 @@ namespace Tensorflow.Train
_maybe_initialize_trackable();
if (!_manual_tracking) return trackable;
var new_reference = new TrackableReference(name, trackable);
var current_object = _lookupup_dependency(name);
var current_object = _lookup_dependency(name);
if(current_object is null)
{
_unconditional_checkpoint_dependencies.Add(new_reference);
@@ -170,7 +180,7 @@ namespace Tensorflow.Train
// TODO: complete the implementation.
}

public virtual Trackable? _lookupup_dependency(string name)
public virtual Trackable? _lookup_dependency(string name)
{
if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency;
else return null;
@@ -199,8 +209,8 @@ namespace Tensorflow.Train
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>());
}

public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null,
IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null)
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable> object_map,
IDictionary<Tensor, Tensor> tensor_map, SaveOptions? options = null)
{
var (self_object_map, self_tensor_map) = map_resources(options);
foreach (var pair in self_object_map)
@@ -215,9 +225,17 @@ namespace Tensorflow.Train
return self_tensor_map.Keys.ToList();
}

public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint()
public virtual IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
return _self_saveable_object_factories;
if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
throw new NotImplementedException();
}
else
{
return _self_saveable_object_factories;
}
}

/// <summary>
@@ -229,7 +247,7 @@ namespace Tensorflow.Train
/// </summary>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public virtual IDictionary<string, object> serialize_to_tensors()
public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{
throw new NotImplementedException();
}


+ 26
- 2
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;
@@ -22,7 +23,7 @@ public static class TrackableUtils
private static string _ESCAPE_CHAR = ".";
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
{
return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name)));
@@ -145,4 +146,27 @@ public static class TrackableUtils
return $"root.{string.Join(".", paths.Select(x => x.Name))}";
}
}

/// <summary>
/// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key.
/// </summary>
/// <param name="key"></param>
/// <param name="prefix"></param>
/// <returns></returns>
public static string extract_local_name(string key, string? prefix = null)
{
if(prefix is null)
{
prefix = "";
}
var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix;
try
{
return key.Substring(key.IndexOf(search_key) + search_key.Length);
}
catch(ArgumentOutOfRangeException)
{
return key;
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Training/data_structures.cs View File

@@ -9,6 +9,7 @@ using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Keras;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;
@@ -243,7 +244,7 @@ namespace Tensorflow.Training
_last_wrapped_list_snapshot = new List<Trackable>(_storage);
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
check_external_modification();
if (_non_append_mutation_value)


+ 3
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -4,6 +4,8 @@ using Tensorflow.Eager;
using Tensorflow.Variables;
using Tensorflow.Train;
using static Tensorflow.Binding;
using System.Collections.Generic;
using Tensorflow.ModelSaving;

namespace Tensorflow
{
@@ -20,6 +22,7 @@ namespace Tensorflow
public string UniqueId => _unique_id;

protected bool _in_graph_mode;
internal bool InGraphMode => _in_graph_mode;

protected bool _trainable;
public bool Trainable => _trainable;


+ 9
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -17,7 +17,9 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using Tensorflow.Checkpoint;
using Tensorflow.NumPy;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -235,5 +237,12 @@ namespace Tensorflow
{
return _graph_element.eval(session);
}

public override IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
var res = new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this;
return res;
}
}
}

+ 2
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;
@@ -351,7 +352,7 @@ namespace Tensorflow.Keras.Engine
return output_tensors;
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache))
.ToDictionary(x => x.Key, x => x.Value);


+ 4
- 3
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;
@@ -9,16 +10,16 @@ public abstract partial class Layer
{
public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);

public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;

public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
IDictionary<string, Trackable> children;
if (save_type == SaveType.SAVEDMODEL)
{
// TODO: deal with cache.
Debug.Assert(cache is not null);
children = TrackableSavedModelSaver.trackable_children(cache);
}
else


+ 22
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -88,9 +88,29 @@ namespace Tensorflow.Keras.Engine

ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;
public Tensor[] input
{
get
{
if(inboundNodes is not null && inboundNodes.Count > 0)
{
return inboundNodes[0].input_tensors;
}
return null;
}
}
public Dictionary<int, List<INode>> NodesByDepth { get; set; }
public Shape OutputShape => inboundNodes[0].Outputs.shape;
public Shape OutputShape
{
get
{
if(inboundNodes is not null && inboundNodes.Count > 0)
{
return inboundNodes[0].Outputs.shape;
}
return null;
}
}
protected List<ILayer> _self_tracked_trackables;

public Layer(LayerArgs args)


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
bool include_optimizer = true,
string save_format = "tf",
SaveOptions? options = null,
IDictionary<string, ConcreteFunction>? signatures = null,
ConcreteFunction? signatures = null,
bool save_traces = true)
{
if (save_format != "pb")


+ 2
- 1
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Engine
}
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if(save_type == SaveType.SAVEDMODEL)
{


+ 2
- 7
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures,
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
SaveOptions? options, bool save_traces = true)
{
if (!overwrite && File.Exists(filepath))
@@ -54,12 +54,7 @@ public partial class KerasSavedModelUtils
}

var metadata = generate_keras_metadata(saved_nodes, node_paths);
using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate,
FileAccess.Write))
{
var writer = new StreamWriter(f);
writer.Write(metadata.ToString());
}
File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray());

if (!include_optimizer)
{


+ 2
- 2
src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs View File

@@ -19,7 +19,7 @@ public partial class KerasSavedModelUtils
/// <param name="layer"></param>
/// <param name="serialization_cache"></param>
/// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, object> serialization_cache)
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs.

@@ -55,7 +55,7 @@ public partial class KerasSavedModelUtils
/// <param name="layer"></param>
/// <param name="serialization_cache"></param>
/// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, object> serialization_cache)
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with type `RevivedLayer` and `Sequential`.



+ 3
- 4
src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs View File

@@ -18,12 +18,12 @@ public abstract class SavedModelSaver
public abstract string TrackingMetadata { get; }

public abstract IDictionary<string, Trackable> objects_to_serialize(
IDictionary<string, object> serialization_cache);
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache);

public abstract IDictionary<string, Trackable> functions_to_serialize(
IDictionary<string, object> serialization_cache);
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache);

public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache)
public IDictionary<string, Trackable> trackable_children(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
if (!KerasSavedModelUtils.ShouldHaveTraces)
{
@@ -31,7 +31,6 @@ public abstract class SavedModelSaver
}

var children = objects_to_serialize(serialization_cache);

return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value))
.ToDictionary(x => x.Key, x => x.Value);
}


+ 20
- 8
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -19,12 +19,12 @@ public class LayerSavedModelSaver: SavedModelSaver
get => Constants.LAYER_IDENTIFIER;
}

public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache)
public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
return get_serialized_attributes(serialization_cache).ObjectsToSerialize;
}

public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache)
public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
return get_serialized_attributes(serialization_cache).FunctionsToSerialize;
}
@@ -33,11 +33,21 @@ public class LayerSavedModelSaver: SavedModelSaver
/// Generates or retrieves serialized attributes from cache.
/// </summary>
/// <param name="serialization_cache"></param>
protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache)
protected ISerializedAttributes get_serialized_attributes(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
// TODO: deal with cache.
IDictionary<Trackable, ISerializedAttributes> keras_cache;
if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY))
{
keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY];
}
else
{
serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary<Trackable, ISerializedAttributes>();
}
if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj];

var serialized_attr = SerializedAttributes.Create(_obj);
var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj);

// TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`.
if (KerasSavedModelUtils.should_skip_serialization(_obj))
@@ -56,7 +66,7 @@ public class LayerSavedModelSaver: SavedModelSaver
/// Returns dictionary of serialized attributes.
/// </summary>
/// <param name="serialization_cache"></param>
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache)
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache);
var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache);
@@ -75,7 +85,7 @@ public class LayerSavedModelSaver: SavedModelSaver
metadata["trainable"] = _obj.Trainable;
// metadata["expects_training_arg"] = _obj._expects_training_arg;
// metadata["dtype"] = policy.serialize(_obj._dtype_policy)
metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape);
metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape);
// metadata["stateful"] = _obj.stateful;
// metadata["must_restore_from_config"] = _obj.must_restore_from_config;
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config;
@@ -92,8 +102,10 @@ public class LayerSavedModelSaver: SavedModelSaver
}
}

public static LayerConfig get_serialized(Layer obj)
public static IDictionary<string, object> get_serialized(Layer obj)
{
return generic_utils.serialize_keras_object(obj);
// TODO: complete the implmentation (need to revise `get_config`).
return new Dictionary<string, object>();
//return generic_utils.serialize_keras_object(obj);
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
/// <summary>
/// Class that tracks and validates all serialization attributes.
/// </summary>
public abstract class SerializedAttributes
public abstract class SerializedAttributes: ISerializedAttributes
{
protected IDictionary<string, Trackable?> _object_dict;
protected IDictionary<string, Trackable?> _function_dict;


+ 2
- 2
test/TensorFlowNET.Keras.UnitTest/SaveTest.cs View File

@@ -50,11 +50,11 @@ public class SaveTest
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 50000,
ValidationSize = 0,
}).Result;
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.save("", save_format:"pb");
model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb");
}
}

Loading…
Cancel
Save