Browse Source

Add lacked implementations (mainly MultiDeviceSaver).

pull/976/head
AsakusaRinne 2 years ago
parent
commit
83906b8f79
30 changed files with 1037 additions and 161 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
  2. +5
    -4
      src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
  3. +12
    -11
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  4. +11
    -16
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  5. +3
    -2
      src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
  6. +6
    -3
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  7. +510
    -5
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  8. +35
    -0
      src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs
  9. +2
    -1
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  10. +91
    -18
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs
  11. +3
    -3
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  12. +8
    -8
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  13. +74
    -25
      src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs
  14. +142
    -14
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  15. +33
    -15
      src/TensorFlowNET.Core/Training/Trackable.cs
  16. +26
    -2
      src/TensorFlowNET.Core/Training/TrackableUtils.cs
  17. +2
    -1
      src/TensorFlowNET.Core/Training/data_structures.cs
  18. +3
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  19. +9
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  20. +2
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  21. +4
    -3
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  22. +22
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  23. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Save.cs
  24. +2
    -1
      src/TensorFlowNET.Keras/Engine/Model.cs
  25. +2
    -7
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  26. +2
    -2
      src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
  27. +3
    -4
      src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
  28. +20
    -8
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  29. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  30. +2
    -2
      test/TensorFlowNET.Keras.UnitTest/SaveTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs View File

@@ -1,5 +1,5 @@
namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


public record class CheckpointOptions( public record class CheckpointOptions(
string experimental_io_device = null,
string? experimental_io_device = null,
bool experimental_enable_async_checkpoint = false); bool experimental_enable_async_checkpoint = false);

+ 5
- 4
src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Serilog.Debugging; using Serilog.Debugging;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train; using Tensorflow.Train;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;
@@ -21,9 +22,9 @@ public class ObjectGraphView: TrackableView, ICloneable
return new ObjectGraphView(Root, _attached_dependencies); return new ObjectGraphView(Root, _attached_dependencies);
} }


public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{ {
List<TrackableReference> res = base.children(obj, save_type)
List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); .Select(x => new TrackableReference(x.Key, x.Value)).ToList();
// Check the reference, not value. // Check the reference, not value.
if (obj == Root && _attached_dependencies is not null) if (obj == Root && _attached_dependencies is not null)
@@ -34,9 +35,9 @@ public class ObjectGraphView: TrackableView, ICloneable
return res; return res;
} }
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{ {
return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer);
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
} }
public IEnumerable<TrackableReference>? AttachedDependencies public IEnumerable<TrackableReference>? AttachedDependencies


+ 12
- 11
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint
); );
public static class SaveUtil public static class SaveUtil
{ {
public static (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{ {
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -117,16 +117,16 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param> /// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param> /// <param name="cache"></param>
/// <param name="object_graph_proto"></param> /// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, object>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{ {
Dictionary<Trackable, IDictionary<string, object>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
foreach(var td in tensor_trackables) foreach(var td in tensor_trackables)
{ {
// TODO: deal with cache. // TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
var trackable = td.object_to_save; var trackable = td.object_to_save;
IDictionary<string, object> tensor_dict;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{ {
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -147,12 +147,12 @@ namespace Tensorflow.Checkpoint
return serialized_tensors; return serialized_tensors;
} }


private static IDictionary<string, object> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{ {
var trackable = trackable_data.object_to_save; var trackable = trackable_data.object_to_save;


// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, object> ret_tensor_dict;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
if (call_with_mapped_captures) if (call_with_mapped_captures)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
@@ -162,8 +162,8 @@ namespace Tensorflow.Checkpoint
ret_tensor_dict = trackable.serialize_to_tensors(); ret_tensor_dict = trackable.serialize_to_tensors();
} }


// TODO: revise the types and complete it
Dictionary<string, object> tensor_dict = new();
// TODO: deal with the type `SaveSpce` (currently it will never be it).
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict) foreach(var pair in ret_tensor_dict)
{ {
var local_name = TrackableUtils.escape_local_name(pair.Key); var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -172,9 +172,10 @@ namespace Tensorflow.Checkpoint


tensor_dict[checkpoint_key] = maybe_tensor; tensor_dict[checkpoint_key] = maybe_tensor;


if(maybe_tensor is SaveSpec)
if(maybe_tensor.GetValueA() is SaveSpec)
{ {
((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
} }


if(object_graph_proto is not null) if(object_graph_proto is not null)
@@ -198,7 +199,7 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param> /// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param> /// <param name="object_graph_proto"></param>
/// <returns></returns> /// <returns></returns>
private static (Trackable, IDictionary<string, object>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{ {
Dictionary<Trackable, string> object_names = new(); Dictionary<Trackable, string> object_names = new();


+ 11
- 16
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -174,25 +174,20 @@ public static class SaveUtilV1
{ {
var name = factory_data.name; var name = factory_data.name;
var key = factory_data.checkpoint_key; var key = factory_data.checkpoint_key;
var saveable_factory = factory_data.factory;
var maybe_saveable = factory_data.factory;
// TODO: oneflow python has a process with callable `saveable_factory`. // TODO: oneflow python has a process with callable `saveable_factory`.
var maybe_saveable = saveable_factory;
IEnumerable<MySaveableObject> savesbles;
if (maybe_saveable is MySaveableObject)
{
savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable };
}
else if (maybe_saveable is Tensor)
List<MySaveableObject> saveables = new();
if (maybe_saveable.DataType == typeof(MySaveableObject))
{ {
savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key);
saveables.Add(maybe_saveable.GetValueB());
} }
else else
{ {
throw new TypeError("Unexpected type.");
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key));
} }


foreach (var saveable in savesbles)
foreach (var saveable in saveables)
{ {
if (!saveable.name.Contains(key)) if (!saveable.name.Contains(key))
{ {
@@ -204,11 +199,11 @@ public static class SaveUtilV1
// skip the process of PythonState // skip the process of PythonState
named_saveable_objects.AddRange(savesbles);
named_saveable_objects.AddRange(saveables);
if(!fill_object_proto) continue; if(!fill_object_proto) continue;
// skip the process of TrackableSaveable
// skip the process of `TrackableSaveable` because of lack of APIs.


object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
{ Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
@@ -221,7 +216,7 @@ public static class SaveUtilV1


public record class CheckpointFactoryData public record class CheckpointFactoryData
( (
object factory,
Maybe<ResourceVariable, MySaveableObject> factory,
string name, string name,
string checkpoint_key string checkpoint_key
); );

+ 3
- 2
src/TensorFlowNET.Core/Checkpoint/TrackableView.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.Train; using Tensorflow.Train;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using Tensorflow.Keras.Saving.SavedModel;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


@@ -18,13 +19,13 @@ public class TrackableView
_root_ref = obj; _root_ref = obj;
} }
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
obj._maybe_initialize_trackable(); obj._maybe_initialize_trackable();
Dictionary<string, Trackable> children = new(); Dictionary<string, Trackable> children = new();
// Note: in python the return type of `Trackable._trackable_children` is not fixed. // Note: in python the return type of `Trackable._trackable_children` is not fixed.
// Therefore it uses `convert_to_trackable` to have an extra process. // Therefore it uses `convert_to_trackable` to have an extra process.
foreach (var pair in obj._trackable_children(save_type))
foreach (var pair in obj._trackable_children(save_type, cache))
{ {
children[pair.Key] = pair.Value; children[pair.Key] = pair.Value;
} }


+ 6
- 3
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -33,7 +33,7 @@ public class TrackableSaver
} }


private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null) gather_serialized_tensors(Tensor? object_graph_tensor = null)
{ {
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);
@@ -125,7 +125,7 @@ public class TrackableSaver
} }


Dictionary<Tensor, string> feed_dict = new(); Dictionary<Tensor, string> feed_dict = new();
bool use_session = (!new Context().executing_eagerly() && !ops.inside_function());
bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
if (checkpoint_number is not null) if (checkpoint_number is not null)
{ {
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
@@ -133,6 +133,7 @@ public class TrackableSaver


Tensor file_prefix_tensor; Tensor file_prefix_tensor;
Tensor object_graph_tensor; Tensor object_graph_tensor;
string file_prefix_to_save;
if (use_session) if (use_session)
{ {
if (_object_graph_feed_tensor is null) if (_object_graph_feed_tensor is null)
@@ -145,16 +146,18 @@ public class TrackableSaver
object_graph_tensor = _object_graph_feed_tensor; object_graph_tensor = _object_graph_feed_tensor;
file_prefix_tensor = _file_prefix_feed_tensor; file_prefix_tensor = _file_prefix_feed_tensor;
feed_dict[file_prefix_tensor] = file_prefix; feed_dict[file_prefix_tensor] = file_prefix;
file_prefix_to_save = "";
} }
else else
{ {
// In python there is `with ops.device("/cpu:0")`. // In python there is `with ops.device("/cpu:0")`.
file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
object_graph_tensor = null; object_graph_tensor = null;
file_prefix_to_save = file_prefix;
} }


var (save_path, new_feed_additions) = var (save_path, new_feed_additions) =
save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);
save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options);


if (new_feed_additions is not null) if (new_feed_additions is not null)
{ {


+ 510
- 5
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -6,9 +6,254 @@ using Tensorflow.Train;
using static Tensorflow.ApiDef.Types; using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types; using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.OptimizerOptions.Types; using static Tensorflow.OptimizerOptions.Types;
using static Tensorflow.Binding;
using System.Text.RegularExpressions;
using System.Linq;
using Tensorflow.Operations;
using Tensorflow.Training;
using Tensorflow.Graphs;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
{ {
/// <summary>
/// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions.
/// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users.
/// </summary>
public interface IFunctionHolder
{
int ArgCount { get; }
object DynamicInvoke(params object[] args);
}
internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder
{
public int ArgCount => 0;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder
{
public int ArgCount => 1;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder
{
public int ArgCount => 2;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder
{
public int ArgCount => 3;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
public class Maybe<TA, TB>
{
private TA? _valueA = default(TA);
private TB? _valueB = default(TB);
private Type _type;
private bool _assigned = false;
public Maybe(TA value)
{
_valueA = value;
_type= typeof(TA);
_assigned = true;
}
public Maybe(TB value)
{
_valueB = value;
_type = typeof(TB);
_assigned = true;
}

public Type DataType => _type;

public TA GetValueA()
{
if(!_assigned || DataType != typeof(TA))
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
return _valueA;
}
public TB GetValueB()
{
if (!_assigned || DataType != typeof(TB))
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
return _valueB;
}
public object GetValue()
{
if (!_assigned)
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
if(DataType == typeof(TA) && _valueA is not null)
{
return _valueA;
}
else if(DataType == typeof(TB) && _valueB is not null)
{
return _valueB;
}
else if(DataType == typeof(TA))
{
return _valueA;
}
else
{
return _valueB;
}
}

public static implicit operator Maybe<TA, TB>(TA a)
{
return new Maybe<TA, TB>(a);
}
public static implicit operator Maybe<TA, TB>(TB b)
{
return new Maybe<TA, TB>(b);
}
}
internal class SingleDeviceSaver
{
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict;
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
}
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}
List<string> tensor_names = new();
List<Tensor> tensors = new();
List<string> slice_specs = new();
foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice in tensor_slices)
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
var tensor_value = spec.tensor;
if (tensor_value is not null)
{
tensor_names.Add(spec.name);
tensors.Add(tensor_value);
slice_specs.Add(spec.slice_spec);
}
}
else
{
var tensor = maybe_tensor.GetValueA();
tensor_names.Add(checkpoint_key);
tensors.Add(tensor);
slice_specs.Add(slice_spec);
}
}
}
// TODO: specify the device.
return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray());
}

public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options);

public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}
List<string> tensor_names = new();
List<TF_DataType> tensor_dtypes = new();
List<string> slice_specs = new();

foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice in tensor_slices)
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
tensor_dtypes.Add(spec.dtype);
slice_specs.Add(spec.slice_spec);
tensor_names.Add(spec.name);
}
else
{
var tensor = maybe_tensor.GetValueA();
tensor_dtypes.Add(tensor.dtype);
slice_specs.Add(slice_spec);
tensor_names.Add(checkpoint_key);
}
}
}

string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

// tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky.
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());

Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0;
foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice_spec in tensor_slices.Keys)
{
var restored_tensor = restored_tensors[idx++];
if (!restored_tensor_dict.ContainsKey(checkpoint_key))
{
restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>();
}
restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor;
}
}
return restored_tensor_dict;
}

public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix));
}
/// <summary> /// <summary>
/// Saves checkpoints directly from multiple devices. /// Saves checkpoints directly from multiple devices.
/// Note that this is a low-level utility which stores Tensors in the keys /// Note that this is a low-level utility which stores Tensors in the keys
@@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint
/// </summary> /// </summary>
public class MultiDeviceSaver public class MultiDeviceSaver
{ {
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, object>> serialized_tensors,
private Dictionary<string, SingleDeviceSaver> _single_device_savers;
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers;
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn;
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys;
/// <summary>
///
/// </summary>
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers"></param>
/// <param name="call_with_mapped_capture"></param>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{ {
_keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>();
_restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>();
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
foreach(var pair in serialized_tensors)
{
var obj = pair.Key;
var tensor_dict = pair.Value;
IFunctionHolder restore_fn;
if(obj is null)
{
restore_fn = new FunctionHolder<object?>(() => null);
}
else
{
restore_fn = null;
// TODO: implement obj._restore_from_tensors
}

foreach(var item in tensor_dict)
{
var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.DataType != typeof(IDictionary<string, Tensor>))
{
spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = item.Value.GetValueA();
}
else
{
spec_to_tensor = item.Value.GetValueB();
}

foreach(var spec in spec_to_tensor)
{
var slice_spec = spec.Key;
var tensor = spec.Value;
if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec)))
{
throw new ValueError("Recieved multiple tensors with the same checkpoint key and " +
$"slice spec. This is invalid because one will overwrite the " +
$"other in the checkpoint. This indicates a bug in the Checkpoint key-generation.");
}
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

// skip the process of device name because lack of API.
var host_device = tensor.Device;
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
if (!internal_dict.ContainsKey(checkpoint_key))
{
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>();
}
internal_dict[checkpoint_key][slice_spec] = tensor;
}
}
}

_single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value));


_registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>();
if(registered_savers is not null && registered_savers.Count > 0)
{
// TODO: complete the implementation.
throw new NotImplementedException();
}
} }


public Operation? save(string file_prefix, CheckpointOptions? options= null)
public Operation save(string file_prefix, CheckpointOptions? options= null)
{ {
throw new NotImplementedException();
if(options is null)
{
options = new CheckpointOptions();
}

tf.device("CPU"); // may be risky.
// TODO: optimize the implementation with new APIs adding to `string_ops`.
string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part";
var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix);
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));

Operation save_fn()
{
List<Tensor> saved_prefixes= new();
foreach(var saver in _registered_savers)
{
// TODO: implementi it later.
throw new NotImplementedException();
}

int num_shards = _single_device_savers.Count;
List<Operation> sharded_saves = new();
var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards");
string? last_device = null;
int shard = 0;
foreach(var pair in _single_device_savers.OrderBy(x => x.Key))
{
var device = pair.Key;
var saver = pair.Value;
last_device = device;
// skip the extra process of device name because of lack of API.
tf.device(device);
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
saved_prefixes.Add(shard_prefix);
sharded_saves.Add(saver.save(shard_prefix, options));
}
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
{
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
tf.device(merge_device);
return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true);
}
}

if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1)
{
// TODO: implement it. Currently `autograph` does not support the function with non parameter.
throw new NotImplementedException();
}
else
{
return save_fn();
}
} }


public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options);

public IDictionary<string, Operation> restore(string file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}

IDictionary<string, Operation> restore_func()
{
Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<string, Operation> restore_ops = new();

foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key))
{
var device = single_saver.Key;
var saver = single_saver.Value;
tf.device(device);
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach(var pair in restored_tensor_dict)
{
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach(var item in slice_and_tensor)
{
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
{
if (!internal_dict.ContainsKey(checkpoint_key))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
}
else
{
internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor;
}
}
else
{
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
}
restore_fn_input_count[restore_fn]--;

if (restore_fn_input_count[restore_fn] == 0)
{
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach(var input in restore_fn_inputs[restore_fn])
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if(ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
}
}
}
}
}

foreach(var item in _registered_savers)
{
throw new NotImplementedException();
}
return restore_ops;
}

// TODO: complete the implementation. Currently skip it because of lack of API.
bool has_custom_device_saver = false;

if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver))
{
// TODO: implement it. Currently `autograph` does not support the function with non parameter.
throw new NotImplementedException();
}
else
{
return restore_func();
}
}

/// <summary>
/// Serializes to a SaverDef referencing the current graph.
/// </summary>
public SaverDef to_proto()
{
var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename");
var save_tensor = _traced_save(filename_tensor);
var restore_op = _traced_restore(filename_tensor).op;
return new SaverDef()
{
FilenameTensorName = filename_tensor.name,
SaveTensorName = save_tensor.name,
RestoreOpName = restore_op.name,
Version = SaverDef.Types.CheckpointFormatVersion.V2
};
}

[AutoGraph]
private Tensor _traced_save(Tensor file_prefix)
{
var save_op = save(file_prefix.StringData()[0]);
tf.device("cpu:0");
using (ops.control_dependencies(new object[]{ save_op }))
{
return array_ops.identity(file_prefix);
}
}

[AutoGraph]
private Tensor _traced_restore(Tensor file_prefix)
{
var restore_op = restore(file_prefix.StringData()[0]);
tf.device("cpu:0");
using (ops.control_dependencies(new object[] { restore_op }))
{
return array_ops.identity(file_prefix);
}
}

private static Tensor registered_saver_filename(string filename, string saver_name)
{
return tf.constant($"{filename}-{saver_name}");
}
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards)
{ {
throw new NotImplementedException();
return filename_tensor;
} }
} }
} }

+ 35
- 0
src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel
{
public interface ISerializedAttributes
{
IDictionary<string, Trackable> Functions { get; }

IDictionary<string, Trackable> CheckpointableObjects { get; }

/// <summary>
/// Returns functions to attach to the root object during serialization.
/// </summary>
IDictionary<string, Trackable> FunctionsToSerialize { get; }

/// <summary>
/// Returns objects to attach to the root object during serialization.
/// </summary>
IDictionary<string, Trackable> ObjectsToSerialize{get; }

/// <summary>
/// Saves function dictionary, and validates dictionary values.
/// </summary>
/// <param name="function_dict"></param>
IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict);

/// <summary>
/// Saves objects to a dictionary, and validates the values.
/// </summary>
/// <param name="object_dict"></param>
IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict);
}
}

+ 2
- 1
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -1,6 +1,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation; using Tensorflow.Operations.Activation;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -24,7 +25,7 @@ namespace Tensorflow.Train
} }
} }


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
if(save_type != SaveType.SAVEDMODEL) if(save_type != SaveType.SAVEDMODEL)
{ {


+ 91
- 18
src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs View File

@@ -4,57 +4,130 @@ using Tensorflow.Train;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;


namespace Tensorflow; namespace Tensorflow;


public class AugmentedGraphView: ObjectGraphView public class AugmentedGraphView: ObjectGraphView
{ {
// private object _children_cache;
// private object _serialization_cache;
private Dictionary<Trackable, IDictionary<string, Trackable>> _children_cache;
private Dictionary<string, IDictionary<Trackable, ISerializedAttributes>> _serialization_cache;
private List<string> _untraces_functions; private List<string> _untraces_functions;
private Dictionary<ConcreteFunction, ConcreteFunction> _wrapped_functions;
public AugmentedGraphView(Trackable root): base(root) public AugmentedGraphView(Trackable root): base(root)
{ {
_untraces_functions = new();
_children_cache= new Dictionary<Trackable, IDictionary<string, Trackable>>();
_serialization_cache = new Dictionary<string, IDictionary<Trackable, ISerializedAttributes>>();
_untraces_functions = new List<string>();
_wrapped_functions = new Dictionary<ConcreteFunction, ConcreteFunction>();
} }


public void set_signature(object signature_map, object wrapped_functions)
public void set_signature(SignatureMap signature_map, IDictionary<ConcreteFunction, ConcreteFunction> wrapped_functions)
{ {
// TODO: cache
list_children(Root); list_children(Root);
var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME;
if (!_children_cache.ContainsKey(Root))
{
_children_cache[Root] = new Dictionary<string, Trackable>();
}
_children_cache[Root][name] = signature_map;
_wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value);
} }
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL)
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{ {
Dictionary<string, Trackable> children = new();
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL))
if(serialization_cache is not null)
{
throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`.");
}

if (!_children_cache.ContainsKey(obj))
{
Dictionary<string, Trackable> children = new Dictionary<string, Trackable>();
_children_cache[obj] = children;
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache))
{
var name = pair.Name;
var child = pair.Refer;
if(child is ConcreteFunction)
{
child = maybe_uncache_variable_captures((ConcreteFunction)child);
}
children[name] = child;
}

if (obj is Function && children.Count == 0)
{
_untraces_functions.Add(((Function)obj).Name);
}
}

List<TrackableReference> res = new();
foreach(var pair in _children_cache[obj])
{ {
var name = pair.Name;
var child = pair.Refer;
children[name] = child;
res.Add(new TrackableReference(pair.Key, pair.Value));
} }


if (obj is Function && children.Count == 0)
return res;
}

private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function)
{
if (_wrapped_functions.ContainsKey(concrete_function))
{ {
_untraces_functions.Add(((Function)obj).Name);
return _wrapped_functions[concrete_function];
} }
// skip the process here because of lack of feature.
// In the future, we may add an attribute which could specify if the variable is supposed to be cached.
//foreach(var capture in concrete_function.CapturedInputs)
//{


return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
//}
return concrete_function;
} }


public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{ {
// TODO: implement it if needed.
Trackable get_merged_trackable(Trackable x)
{
// TODO: complete it with new definitions `Asset` and `TrackableConstant`.
return x;
}
var trackable_objects = base.breadth_first_traversal();

foreach(var obj in _children_cache.Keys)
{
// skip the deletion of cache (maybe do it later).
foreach(var pair in _children_cache[obj])
{
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
}
}

return base.breadth_first_traversal(); return base.breadth_first_traversal();
} }


public List<(string, Trackable)> list_dependencies(Trackable obj) public List<(string, Trackable)> list_dependencies(Trackable obj)
{ {
// TODO: deal with cache.
return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList();
IDictionary<string, Trackable> children;
if (!_children_cache.ContainsKey(obj))
{
children= new Dictionary<string, Trackable>();
}
else
{
children= _children_cache[obj];
}
List<(string, Trackable)> res = new();
foreach(var pair in obj.deserialization_dependencies(children))
{
res.Add((pair.Key, pair.Value));
}
return res;
} }


public Trackable get_child(Trackable obj, string name) public Trackable get_child(Trackable obj, string name)
{ {
throw new NotImplementedException();
return _children_cache[obj][name];
} }
} }

+ 3
- 3
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -141,16 +141,16 @@ public class SaveableView
foreach (var node in _nodes) foreach (var node in _nodes)
{ {
var node_id = _node_ids[node]; var node_id = _node_ids[node];
List<int> deps = new();
List<int> deps = new List<int>();
dependency_map.Add(node_id, deps);
// TODO: deal with captured tensor. // TODO: deal with captured tensor.


string node_path;
foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node))
{ {
if (!_node_ids.ContainsKey(dep)) if (!_node_ids.ContainsKey(dep))
{ {
node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
throw new ValueError( throw new ValueError(
$"Found an untracked dependency. Object {node_path} depends on {dep}, " + $"Found an untracked dependency. Object {node_path} depends on {dep}, " +
$"but this dependency isn't listed as a child. Please track this child by " + $"but this dependency isn't listed as a child. Please track this child by " +


+ 8
- 8
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -24,7 +24,7 @@ public static partial class SavedModelUtils
}.Select(x => (int)x); }.Select(x => (int)x);
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj, public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj,
string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
{ {
if (options is null) if (options is null)
{ {
@@ -41,9 +41,9 @@ public static partial class SavedModelUtils


if (!experimental_skip_checkpoint) if (!experimental_skip_checkpoint)
{ {
Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir);
SavedModelUtils.get_or_create_variables_dir(export_dir);
CheckpointOptions ckpt_options = new(options.experimental_io_device); CheckpointOptions ckpt_options = new(options.experimental_io_device);
object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options);
object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options);
} }
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir);


@@ -67,7 +67,7 @@ public static partial class SavedModelUtils
} }


var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB));
File.WriteAllText(path, saved_model.ToString());
File.WriteAllBytes(path, saved_model.ToByteArray());


if (options.save_debug_info) if (options.save_debug_info)
{ {
@@ -81,7 +81,7 @@ public static partial class SavedModelUtils


private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>,
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
{ {
if (ops.inside_function()) if (ops.inside_function())
{ {
@@ -95,9 +95,9 @@ public static partial class SavedModelUtils
} }


AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
if (signatures is not null)
if (signatures is null)
{ {
throw new NotImplementedException();
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view);
} }
// TODO: process of aignatures and wrapped_functions // TODO: process of aignatures and wrapped_functions
@@ -125,7 +125,7 @@ public static partial class SavedModelUtils
} }


private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view,
IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist,
ConcreteFunction signatures, IEnumerable<string> namespace_whitelist,
bool save_custom_gradients) bool save_custom_gradients)
{ {
var resource_initializers = saveable_view.get_concrete_resource_initializers(); var resource_initializers = saveable_view.get_concrete_resource_initializers();


+ 74
- 25
src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs View File

@@ -1,15 +1,84 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train; using Tensorflow.Train;


namespace Tensorflow; namespace Tensorflow;


public static class SignatureSerializationUtils
{
internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature";
internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures";
internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5;
public static SignatureMap create_signature_map(IDictionary<string, Trackable> signatures)
{
var signature_map = new SignatureMap();
foreach (var pair in signatures)
{
var name = pair.Key;
var func = pair.Value;
Debug.Assert(func is ConcreteFunction);
// TODO: assert the `func.structured_outputs` and arg_keywords.
signature_map._add_signature(name, (ConcreteFunction)func);
}

return signature_map;
}

public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view)
{
var children = graph_view.list_children(graph_view.Root);
List<Trackable> possible_signatures = new();
foreach (var item in children)
{
var name = item.Name;
var child = item.Refer;
if(child is not (Function or ConcreteFunction))
{
continue;
}
if(name == DEFAULT_SIGNATURE_ATTR)
{
Debug.Assert(child is ConcreteFunction);
return (ConcreteFunction)child;
}
ConcreteFunction concrete = get_signature(child);
if(concrete is not null && valid_signature(concrete))
{
possible_signatures.Add(concrete);
}
}

if(possible_signatures.Count == 1)
{
var signature = get_signature(possible_signatures[0]);
if(signature is not null && valid_signature(signature))
{
return signature;
}
}
return null;
}

private static ConcreteFunction get_signature(Trackable function)
{
// TODO: implement it.
return null;
}

private static bool valid_signature(ConcreteFunction concreate_function)
{
// TODO: implement it.
return false;
}
}

public class SignatureMap: Trackable public class SignatureMap: Trackable
{ {
private Dictionary<string, Function> _signatures;
private Dictionary<string, ConcreteFunction> _concrete_signatures;
private Dictionary<string, Trackable> _signatures;


public SignatureMap() public SignatureMap()
{ {
@@ -18,7 +87,7 @@ public class SignatureMap: Trackable


public void _add_signature(string name, ConcreteFunction concrete_function) public void _add_signature(string name, ConcreteFunction concrete_function)
{ {
_concrete_signatures[name] = concrete_function;
_signatures[name] = concrete_function;
} }
public void _add_signature(string name, Function concrete_function) public void _add_signature(string name, Function concrete_function)
@@ -26,33 +95,13 @@ public class SignatureMap: Trackable
_signatures[name] = concrete_function; _signatures[name] = concrete_function;
} }


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
if (save_type != SaveType.SAVEDMODEL) if (save_type != SaveType.SAVEDMODEL)
{ {
return new Dictionary<string, Trackable>(); return new Dictionary<string, Trackable>();
} }


Dictionary<string, Trackable> res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value);
foreach (var pair in _concrete_signatures)
{
res[pair.Key] = pair.Value;
}

return res;
}

public static SignatureMap create_signature_map(IDictionary<string, ConcreteFunction> signatures)
{
var signature_map = new SignatureMap();
foreach (var pair in signatures)
{
var name = pair.Key;
var func = pair.Value;
// TODO: assert the arg_keywords
signature_map._add_signature(name, func);
}

return signature_map;
return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
} }
} }

+ 142
- 14
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -16,18 +16,38 @@


using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public static class saveable_object_util
/// <summary>
/// A SaveableObject that defines `Trackable` checkpointing steps.
/// </summary>
public class TrackableSaveable : MySaveableObject
{ {
public class TrackableSaveable: MySaveableObject
private string _prefix;
private IEnumerable<string> _local_names;
private Trackable _trackable;
private bool _call_with_mapped_captures;
// TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict.
public TrackableSaveable(Trackable obj, IEnumerable<SaveSpec> specs, string name, IEnumerable<string> local_names,
string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name)
{ {
_prefix = prefix;
_trackable = obj;
_local_names = local_names;
_call_with_mapped_captures = call_with_mapped_captures;
} }

// TODO: complete this class.
}
public static class saveable_object_util
{
/// <summary> /// <summary>
/// Returns the variables and names that will be used for a Saver. /// Returns the variables and names that will be used for a Saver.
/// </summary> /// </summary>
@@ -57,7 +77,7 @@ namespace Tensorflow
} }


/// <summary> /// <summary>
/// Create `SaveableObject`s from an operation.
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`.
/// </summary> /// </summary>
/// <param name="op"></param> /// <param name="op"></param>
/// <param name="name"></param> /// <param name="name"></param>
@@ -79,6 +99,74 @@ namespace Tensorflow
} }
} }


/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Trackable obj, string name)
{
// The `op` maybe `Variable` or `Trackable`.
if (obj is BaseResourceVariable)
{
var variable = obj as BaseResourceVariable;
if (variable.InGraphMode)
{
yield return new ResourceVariableSaveable(variable.GraphElement, "", name);
}
else
{
Debug.Assert(variable is ResourceVariable);
yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name);
}
}
else
{
foreach(var pair in saveable_objects_from_trackable(obj))
{
var attr = pair.Key;
var factory = pair.Value;
string full_name;
if(attr == Trackable.Constants.VARIABLE_VALUE_KEY)
{
full_name = name;
}
else
{
full_name = name + "_" + attr;
}
if(factory.DataType == typeof(ResourceVariable))
{
var variable = factory.GetValueA();
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
{
yield return op;
}
}
else
{
var variable = factory.GetValueB();
foreach (var op in saveable_objects_for_op(variable, variable.name))
{
yield return op;
}
}
}
}
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(MySaveableObject obj, string name)
{
yield return obj;
}

public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true)
{ {
op_list = op_list.OrderBy(x => x.Name).ToArray(); op_list = op_list.OrderBy(x => x.Name).ToArray();
@@ -127,16 +215,55 @@ namespace Tensorflow
return names_to_saveables; return names_to_saveables;
} }


public static IDictionary<string, ResourceVariable> saveable_objects_from_trackable(Trackable obj)
public static IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
{ {
// TODO: complete the implementation.
return obj.gather_saveables_for_checkpoint();
// skip the process of type `PythonState`

if (trackable_has_serialize_to_tensor(obj))
{
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
var tensor_dict = obj.serialize_to_tensors();

List<SaveSpec> specs = new();
List<string> local_names = new();
string prefix = SaveableCompat.get_saveable_name(obj) ?? "";
foreach(var pair in tensor_dict)
{
var tensor_name = pair.Key;
var maybe_tensor = pair.Value;
local_names.Add(tensor_name);
string spec_name = name + TrackableUtils.escape_local_name(tensor_name);

IDictionary<string, Tensor> internal_dict;
if(maybe_tensor.DataType == typeof(Tensor))
{
internal_dict= new Dictionary<string, Tensor>();
internal_dict[""] = maybe_tensor.GetValueA();
}
else
{
internal_dict = maybe_tensor.GetValueB();
}

foreach(var item in internal_dict)
{
specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
}
}
Dictionary<string, Maybe<ResourceVariable, MySaveableObject>> res = new();
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
return res;
}
else
{
return obj.gather_saveables_for_checkpoint();
}
} }


public static bool trackable_has_serialize_to_tensor(Trackable obj) public static bool trackable_has_serialize_to_tensor(Trackable obj)
{ {
// TODO: implement it.
return false;
return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable);
} }


internal static string convert_to_string(string x) internal static string convert_to_string(string x)
@@ -158,27 +285,28 @@ namespace Tensorflow
public Trackable Obj => _obj; public Trackable Obj => _obj;
public IList<MySaveableObject> mySaveables=> _saveables; public IList<MySaveableObject> mySaveables=> _saveables;


public override IDictionary<string, object> serialize_to_tensors()
public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{ {
return saveable_objects_to_tensor_dict(_saveables);
return saveable_object_to_tensor_dict(_saveables);
} }


/// <summary> /// <summary>
/// Converts a list of SaveableObjects to a tensor dictionary. /// Converts a list of SaveableObjects to a tensor dictionary.
/// </summary> /// </summary>
/// <param name="saveables"></param> /// <param name="saveables"></param>
public static Dictionary<string, object> saveable_objects_to_tensor_dict(IList<MySaveableObject> saveables)
public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
{ {
Dictionary<string, object> tensor_dict = new();
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach (var saveable in saveables) foreach (var saveable in saveables)
{ {
foreach(var spec in saveable.specs) foreach(var spec in saveable.specs)
{ {
// skip the check that if `spec` is callable.
var name = saveable_object_util.convert_to_string(spec.name); var name = saveable_object_util.convert_to_string(spec.name);
var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec)) if (!string.IsNullOrEmpty(slice_spec))
{ {
throw new NotImplementedException();
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor;
} }
else else
{ {


+ 33
- 15
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -16,7 +16,10 @@


using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving; using Tensorflow.ModelSaving;
using Tensorflow.Training; using Tensorflow.Training;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -39,8 +42,8 @@ namespace Tensorflow.Train


protected IList<TrackableReference> _unconditional_checkpoint_dependencies; protected IList<TrackableReference> _unconditional_checkpoint_dependencies;


protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
new Dictionary<string, ResourceVariable>();
protected IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>();
private bool _manual_tracking = true; private bool _manual_tracking = true;


private static Trackable _none = new Function(); private static Trackable _none = new Function();
@@ -94,9 +97,13 @@ namespace Tensorflow.Train
// assign again. It will add this variable to our dependencies, and if there // assign again. It will add this variable to our dependencies, and if there
// is a non-trivial restoration queued, it will handle that. This also // is a non-trivial restoration queued, it will handle that. This also
// handles slot variables. // handles slot variables.
if (!args.Overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: args.Name,
overwrite: args.Overwrite);
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{
var temp = new_variable as Trackable;
var res = _track_trackable(temp, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1);
return res as IVariableV1;
}
else else
return new_variable; return new_variable;
} }
@@ -122,13 +129,16 @@ namespace Tensorflow.Train
/// </summary> /// </summary>
public void _maybe_initialize_trackable() public void _maybe_initialize_trackable()
{ {
if(_unconditional_checkpoint_dependencies is not null)
{
return;
}
_self_update_uid = -1; _self_update_uid = -1;
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); _unconditional_checkpoint_dependencies = new List<TrackableReference>();
_unconditional_dependency_names = new Dictionary<string, Trackable>(); _unconditional_dependency_names = new Dictionary<string, Trackable>();
} }


// TODO: cache
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache)
{ {
_maybe_initialize_trackable(); _maybe_initialize_trackable();
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
@@ -139,8 +149,8 @@ namespace Tensorflow.Train
_maybe_initialize_trackable(); _maybe_initialize_trackable();
if (!_manual_tracking) return trackable; if (!_manual_tracking) return trackable;
var new_reference = new TrackableReference(name, trackable); var new_reference = new TrackableReference(name, trackable);
var current_object = _lookupup_dependency(name);
var current_object = _lookup_dependency(name);
if(current_object is null) if(current_object is null)
{ {
_unconditional_checkpoint_dependencies.Add(new_reference); _unconditional_checkpoint_dependencies.Add(new_reference);
@@ -170,7 +180,7 @@ namespace Tensorflow.Train
// TODO: complete the implementation. // TODO: complete the implementation.
} }


public virtual Trackable? _lookupup_dependency(string name)
public virtual Trackable? _lookup_dependency(string name)
{ {
if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency;
else return null; else return null;
@@ -199,8 +209,8 @@ namespace Tensorflow.Train
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>());
} }


public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null,
IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null)
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable> object_map,
IDictionary<Tensor, Tensor> tensor_map, SaveOptions? options = null)
{ {
var (self_object_map, self_tensor_map) = map_resources(options); var (self_object_map, self_tensor_map) = map_resources(options);
foreach (var pair in self_object_map) foreach (var pair in self_object_map)
@@ -215,9 +225,17 @@ namespace Tensorflow.Train
return self_tensor_map.Keys.ToList(); return self_tensor_map.Keys.ToList();
} }


public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint()
public virtual IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{ {
return _self_saveable_object_factories;
if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
throw new NotImplementedException();
}
else
{
return _self_saveable_object_factories;
}
} }


/// <summary> /// <summary>
@@ -229,7 +247,7 @@ namespace Tensorflow.Train
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
/// <exception cref="NotImplementedException"></exception> /// <exception cref="NotImplementedException"></exception>
public virtual IDictionary<string, object> serialize_to_tensors()
public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }


+ 26
- 2
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Exceptions; using Tensorflow.Exceptions;
using Tensorflow.Train; using Tensorflow.Train;
@@ -22,7 +23,7 @@ public static class TrackableUtils
private static string _ESCAPE_CHAR = "."; private static string _ESCAPE_CHAR = ".";
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
{ {
return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name)));
@@ -145,4 +146,27 @@ public static class TrackableUtils
return $"root.{string.Join(".", paths.Select(x => x.Name))}"; return $"root.{string.Join(".", paths.Select(x => x.Name))}";
} }
} }

/// <summary>
/// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key.
/// </summary>
/// <param name="key"></param>
/// <param name="prefix"></param>
/// <returns></returns>
public static string extract_local_name(string key, string? prefix = null)
{
if(prefix is null)
{
prefix = "";
}
var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix;
try
{
return key.Substring(key.IndexOf(search_key) + search_key.Length);
}
catch(ArgumentOutOfRangeException)
{
return key;
}
}
} }

+ 2
- 1
src/TensorFlowNET.Core/Training/data_structures.cs View File

@@ -9,6 +9,7 @@ using System.Runtime.InteropServices;
using System.Text; using System.Text;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Keras; using Tensorflow.Keras;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation; using Tensorflow.Operations.Activation;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.ApiDef.Types; using static Tensorflow.ApiDef.Types;
@@ -243,7 +244,7 @@ namespace Tensorflow.Training
_last_wrapped_list_snapshot = new List<Trackable>(_storage); _last_wrapped_list_snapshot = new List<Trackable>(_storage);
} }


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
check_external_modification(); check_external_modification();
if (_non_append_mutation_value) if (_non_append_mutation_value)


+ 3
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -4,6 +4,8 @@ using Tensorflow.Eager;
using Tensorflow.Variables; using Tensorflow.Variables;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Collections.Generic;
using Tensorflow.ModelSaving;


namespace Tensorflow namespace Tensorflow
{ {
@@ -20,6 +22,7 @@ namespace Tensorflow
public string UniqueId => _unique_id; public string UniqueId => _unique_id;


protected bool _in_graph_mode; protected bool _in_graph_mode;
internal bool InGraphMode => _in_graph_mode;


protected bool _trainable; protected bool _trainable;
public bool Trainable => _trainable; public bool Trainable => _trainable;


+ 9
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -17,7 +17,9 @@
using Google.Protobuf; using Google.Protobuf;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Checkpoint;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -235,5 +237,12 @@ namespace Tensorflow
{ {
return _graph_element.eval(session); return _graph_element.eval(session);
} }

public override IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
var res = new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this;
return res;
}
} }
} }

+ 2
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils; using Tensorflow.Keras.Utils;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -351,7 +352,7 @@ namespace Tensorflow.Keras.Engine
return output_tensors; return output_tensors;
} }


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache))
.ToDictionary(x => x.Key, x => x.Value); .ToDictionary(x => x.Key, x => x.Value);


+ 4
- 3
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train; using Tensorflow.Train;
@@ -9,16 +10,16 @@ public abstract partial class Layer
{ {
public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);


public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;


public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
IDictionary<string, Trackable> children; IDictionary<string, Trackable> children;
if (save_type == SaveType.SAVEDMODEL) if (save_type == SaveType.SAVEDMODEL)
{ {
// TODO: deal with cache.
Debug.Assert(cache is not null);
children = TrackableSavedModelSaver.trackable_children(cache); children = TrackableSavedModelSaver.trackable_children(cache);
} }
else else


+ 22
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -88,9 +88,29 @@ namespace Tensorflow.Keras.Engine


ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value; public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;
public Tensor[] input
{
get
{
if(inboundNodes is not null && inboundNodes.Count > 0)
{
return inboundNodes[0].input_tensors;
}
return null;
}
}
public Dictionary<int, List<INode>> NodesByDepth { get; set; } public Dictionary<int, List<INode>> NodesByDepth { get; set; }
public Shape OutputShape => inboundNodes[0].Outputs.shape;
public Shape OutputShape
{
get
{
if(inboundNodes is not null && inboundNodes.Count > 0)
{
return inboundNodes[0].Outputs.shape;
}
return null;
}
}
protected List<ILayer> _self_tracked_trackables; protected List<ILayer> _self_tracked_trackables;


public Layer(LayerArgs args) public Layer(LayerArgs args)


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
bool include_optimizer = true, bool include_optimizer = true,
string save_format = "tf", string save_format = "tf",
SaveOptions? options = null, SaveOptions? options = null,
IDictionary<string, ConcreteFunction>? signatures = null,
ConcreteFunction? signatures = null,
bool save_traces = true) bool save_traces = true)
{ {
if (save_format != "pb") if (save_format != "pb")


+ 2
- 1
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;
@@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Engine
} }
} }


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
if(save_type == SaveType.SAVEDMODEL) if(save_type == SaveType.SAVEDMODEL)
{ {


+ 2
- 7
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;


public partial class KerasSavedModelUtils public partial class KerasSavedModelUtils
{ {
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures,
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
SaveOptions? options, bool save_traces = true) SaveOptions? options, bool save_traces = true)
{ {
if (!overwrite && File.Exists(filepath)) if (!overwrite && File.Exists(filepath))
@@ -54,12 +54,7 @@ public partial class KerasSavedModelUtils
} }


var metadata = generate_keras_metadata(saved_nodes, node_paths); var metadata = generate_keras_metadata(saved_nodes, node_paths);
using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate,
FileAccess.Write))
{
var writer = new StreamWriter(f);
writer.Write(metadata.ToString());
}
File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray());


if (!include_optimizer) if (!include_optimizer)
{ {


+ 2
- 2
src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs View File

@@ -19,7 +19,7 @@ public partial class KerasSavedModelUtils
/// <param name="layer"></param> /// <param name="layer"></param>
/// <param name="serialization_cache"></param> /// <param name="serialization_cache"></param>
/// <returns></returns> /// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, object> serialization_cache)
public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs.


@@ -55,7 +55,7 @@ public partial class KerasSavedModelUtils
/// <param name="layer"></param> /// <param name="layer"></param>
/// <param name="serialization_cache"></param> /// <param name="serialization_cache"></param>
/// <returns></returns> /// <returns></returns>
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, object> serialization_cache)
public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
// TODO: deal with type `RevivedLayer` and `Sequential`. // TODO: deal with type `RevivedLayer` and `Sequential`.




+ 3
- 4
src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs View File

@@ -18,12 +18,12 @@ public abstract class SavedModelSaver
public abstract string TrackingMetadata { get; } public abstract string TrackingMetadata { get; }


public abstract IDictionary<string, Trackable> objects_to_serialize( public abstract IDictionary<string, Trackable> objects_to_serialize(
IDictionary<string, object> serialization_cache);
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache);


public abstract IDictionary<string, Trackable> functions_to_serialize( public abstract IDictionary<string, Trackable> functions_to_serialize(
IDictionary<string, object> serialization_cache);
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache);


public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache)
public IDictionary<string, Trackable> trackable_children(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
if (!KerasSavedModelUtils.ShouldHaveTraces) if (!KerasSavedModelUtils.ShouldHaveTraces)
{ {
@@ -31,7 +31,6 @@ public abstract class SavedModelSaver
} }


var children = objects_to_serialize(serialization_cache); var children = objects_to_serialize(serialization_cache);

return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value))
.ToDictionary(x => x.Key, x => x.Value); .ToDictionary(x => x.Key, x => x.Value);
} }


+ 20
- 8
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -19,12 +19,12 @@ public class LayerSavedModelSaver: SavedModelSaver
get => Constants.LAYER_IDENTIFIER; get => Constants.LAYER_IDENTIFIER;
} }


public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache)
public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
return get_serialized_attributes(serialization_cache).ObjectsToSerialize; return get_serialized_attributes(serialization_cache).ObjectsToSerialize;
} }


public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache)
public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
return get_serialized_attributes(serialization_cache).FunctionsToSerialize; return get_serialized_attributes(serialization_cache).FunctionsToSerialize;
} }
@@ -33,11 +33,21 @@ public class LayerSavedModelSaver: SavedModelSaver
/// Generates or retrieves serialized attributes from cache. /// Generates or retrieves serialized attributes from cache.
/// </summary> /// </summary>
/// <param name="serialization_cache"></param> /// <param name="serialization_cache"></param>
protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache)
protected ISerializedAttributes get_serialized_attributes(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
// TODO: deal with cache. // TODO: deal with cache.
IDictionary<Trackable, ISerializedAttributes> keras_cache;
if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY))
{
keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY];
}
else
{
serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary<Trackable, ISerializedAttributes>();
}
if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj];


var serialized_attr = SerializedAttributes.Create(_obj);
var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj);


// TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`.
if (KerasSavedModelUtils.should_skip_serialization(_obj)) if (KerasSavedModelUtils.should_skip_serialization(_obj))
@@ -56,7 +66,7 @@ public class LayerSavedModelSaver: SavedModelSaver
/// Returns dictionary of serialized attributes. /// Returns dictionary of serialized attributes.
/// </summary> /// </summary>
/// <param name="serialization_cache"></param> /// <param name="serialization_cache"></param>
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache)
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{ {
var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache);
var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache);
@@ -75,7 +85,7 @@ public class LayerSavedModelSaver: SavedModelSaver
metadata["trainable"] = _obj.Trainable; metadata["trainable"] = _obj.Trainable;
// metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["expects_training_arg"] = _obj._expects_training_arg;
// metadata["dtype"] = policy.serialize(_obj._dtype_policy) // metadata["dtype"] = policy.serialize(_obj._dtype_policy)
metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape);
metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape);
// metadata["stateful"] = _obj.stateful; // metadata["stateful"] = _obj.stateful;
// metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["must_restore_from_config"] = _obj.must_restore_from_config;
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config;
@@ -92,8 +102,10 @@ public class LayerSavedModelSaver: SavedModelSaver
} }
} }


public static LayerConfig get_serialized(Layer obj)
public static IDictionary<string, object> get_serialized(Layer obj)
{ {
return generic_utils.serialize_keras_object(obj);
// TODO: complete the implmentation (need to revise `get_config`).
return new Dictionary<string, object>();
//return generic_utils.serialize_keras_object(obj);
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
/// <summary> /// <summary>
/// Class that tracks and validates all serialization attributes. /// Class that tracks and validates all serialization attributes.
/// </summary> /// </summary>
public abstract class SerializedAttributes
public abstract class SerializedAttributes: ISerializedAttributes
{ {
protected IDictionary<string, Trackable?> _object_dict; protected IDictionary<string, Trackable?> _object_dict;
protected IDictionary<string, Trackable?> _function_dict; protected IDictionary<string, Trackable?> _function_dict;


+ 2
- 2
test/TensorFlowNET.Keras.UnitTest/SaveTest.cs View File

@@ -50,11 +50,11 @@ public class SaveTest
{ {
TrainDir = "mnist", TrainDir = "mnist",
OneHot = false, OneHot = false,
ValidationSize = 50000,
ValidationSize = 0,
}).Result; }).Result;
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.save("", save_format:"pb");
model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb");
} }
} }

Loading…
Cancel
Save