Browse Source

Add more facilities to the saved model framework.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
c4114d5f18
13 changed files with 685 additions and 153 deletions
  1. +253
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  2. +8
    -10
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  3. +16
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs
  4. +0
    -109
      src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs
  5. +7
    -1
      src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
  6. +191
    -0
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  7. +36
    -0
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  8. +10
    -0
      src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs
  9. +50
    -1
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs
  11. +26
    -26
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  12. +50
    -0
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  13. +37
    -5
      src/TensorFlowNET.Core/Training/Trackable.cs

+ 253
- 0
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -0,0 +1,253 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// Per-object bookkeeping gathered while walking the trackable object graph.
/// One instance is created for each trackable returned by the breadth-first traversal.
/// </summary>
internal record class TrackableData(
// A trackable in the root Trackable object graph.
Trackable trackable,
// The index at which the Trackable appears in TrackableObjectGraph.nodes.
int node_id,
// The BFS-generated path from the root object / used to generate readable checkpoint keys.
string object_name,
// A list of ObjectReference for each child connected to this Trackable.
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
// A list of SlotVariableReference to save to the object (only valid for Optimizer objects).
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
// The object to save to checkpoint. Usually this is the same as `trackable`,
// but can differ when the caller wants to specify a different object to
// save. For example, when saving checkpoints asynchronously, variables are
// copied to the CPU. `object_to_save` is set as the copied variable.
Trackable object_to_save
);
public static class SaveUtil
{
/// <summary>
/// Walks <paramref name="graph_view"/>, builds the object graph proto and collects the tensors to checkpoint.
/// </summary>
/// <param name="graph_view">View over the trackable object graph to serialize.</param>
/// <param name="object_map">Optional mapping from a trackable to the object that should be saved in its place.</param>
/// <param name="call_with_mapped_captures">Forwarded to the tensor-gathering helpers.</param>
/// <param name="cache">Not supported yet; any non-null value results in <see cref="NotImplementedException"/>.</param>
/// <returns>
/// The serialized tensors per trackable, the feed additions (always null while the cache is unsupported),
/// the registered savers and the object graph proto.
/// </returns>
/// <exception cref="NotImplementedException">Thrown when <paramref name="cache"/> is not null.</exception>
public static (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
    serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
    var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
    var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

    var object_graph_proto = fill_object_graph_proto(trackable_data);

    var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
    var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

    if (cache is not null)
    {
        // TODO: deal with cache.
        // The cached path is simply not ported yet, so report it as such
        // instead of raising an unrelated arithmetic exception.
        throw new NotImplementedException();
    }

    Dictionary<Tensor, string> feed_additions = null;

    // Merge with "last write wins" semantics (the equivalent of Python's dict.update);
    // Concat(...).ToDictionary(...) would throw as soon as both dictionaries share a key.
    var pystate_tensors = get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
        cache, object_graph_proto);
    foreach (var pair in pystate_tensors)
    {
        serialized_tensors[pair.Key] = pair.Value;
    }

    CheckPointUtils.add_checkpoint_values_check(object_graph_proto);

    return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
}

/// <summary>
/// Traverses the graph view breadth-first and records, for every visited trackable,
/// its node id (its position in the traversal), readable name, child references,
/// slot variables and the object that should actually be saved.
/// </summary>
private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
{
    var (bfs_nodes, paths_by_node) = graph_view.breadth_first_traversal();

    // Readable name for every node, derived from its path from the root.
    var names_by_node = new Dictionary<Trackable, string>();
    foreach (var entry in paths_by_node)
    {
        names_by_node[entry.Key] = TrackableUtils.object_path_to_string(entry.Value);
    }

    // A node's id is its index in the breadth-first order.
    var node_ids = new Dictionary<Trackable, int>();
    for (int index = 0; index < bfs_nodes.Count; index++)
    {
        node_ids[bfs_nodes[index]] = index;
    }

    var slots_by_node = CheckPointUtils.serialize_slot_variables(bfs_nodes, node_ids, names_by_node);

    var gathered = new List<TrackableData>();
    foreach (var node in bfs_nodes)
    {
        var child_refs = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
        foreach (var child in graph_view.list_children(node))
        {
            var reference = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
            {
                NodeId = node_ids[child.Refer],
                LocalName = child.Name
            };
            child_refs.Add(reference);
        }

        // Nodes without slot variables get an empty list rather than null.
        if (!slots_by_node.TryGetValue(node, out var slot_refs) || slot_refs is null)
        {
            slot_refs = new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>();
        }

        var data = new TrackableData(
            trackable: node,
            node_id: node_ids[node],
            object_name: names_by_node[node],
            children_proto: child_refs,
            slot_variable_proto: slot_refs,
            object_to_save: CheckPointUtils.get_mapped_trackable(node, object_map));
        gathered.Add(data);
    }

    return (gathered, node_ids);
}

/// <summary>
/// Creates the object graph proto with one node per entry of <paramref name="trackable_data"/>,
/// in list order. Each node is built from the entry's slot variable and children references.
/// </summary>
private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
{
    var graph = new TrackableObjectGraph();
    int expected_id = 0;
    foreach (var item in trackable_data)
    {
        // The node id must match the position at which the node is appended.
        Debug.Assert(item.node_id == expected_id);
        var node = new TrackableObjectGraph.Types.TrackableObject(item.slot_variable_proto, item.children_proto);
        graph.Nodes.Add(node);
        expected_id++;
    }
    return graph;
}

/// <summary>
/// Builds the map from trackable to its checkpoint tensors and records the
/// corresponding attributes in <paramref name="object_graph_proto"/>.
/// </summary>
/// <param name="tensor_trackables">Trackables whose tensors should be collected.</param>
/// <param name="node_ids">Node id of every trackable in the object graph.</param>
/// <param name="call_with_mapped_captures">Forwarded to the per-trackable helpers.</param>
/// <param name="cache">Currently unused (the cache is not ported yet).</param>
/// <param name="object_graph_proto">Proto that is updated by the per-trackable helpers.</param>
private static IDictionary<Trackable, IDictionary<string, object>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
    bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
    var tensors_by_trackable = new Dictionary<Trackable, IDictionary<string, object>>();
    foreach (var data in tensor_trackables)
    {
        // TODO: deal with cache.
        string legacy_name = SaveableCompat.get_saveable_name(data.object_to_save) ?? "";

        // The direct path is only taken when the object can serialize itself
        // and carries no legacy saveable name.
        bool serialize_directly = saveable_object_util.trackable_has_serialize_to_tensor(data.object_to_save)
            && legacy_name.Length == 0;

        Trackable key = data.object_to_save;
        IDictionary<string, object> tensors;
        if (serialize_directly)
        {
            tensors = get_tensors_from_trackable(data, call_with_mapped_captures, object_graph_proto);
        }
        else
        {
            (key, tensors) = get_tensors_from_legacy_saveable(data, node_ids, call_with_mapped_captures, object_graph_proto);
        }

        // A null key cannot be stored in a Dictionary, so `Trackable.None` stands in for it.
        tensors_by_trackable[key ?? Trackable.None] = tensors;
    }
    return tensors_by_trackable;
}

/// <summary>
/// Calls <c>serialize_to_tensors</c> on the object to save, re-keys the result by checkpoint key
/// and, when a proto is supplied, records one attribute per tensor on the object's node.
/// </summary>
/// <exception cref="NotImplementedException">Thrown when <paramref name="call_with_mapped_captures"/> is true.</exception>
private static IDictionary<string, object> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
    // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
    if (call_with_mapped_captures)
    {
        throw new NotImplementedException();
    }

    var source = trackable_data.object_to_save;
    var raw_tensors = source.serialize_to_tensors();

    // TODO: revise the types and complete it
    var keyed_tensors = new Dictionary<string, object>();
    foreach (var entry in raw_tensors)
    {
        var escaped_name = TrackableUtils.escape_local_name(entry.Key);
        var key = TrackableUtils.checkpoint_key(trackable_data.object_name, escaped_name);

        keyed_tensors[key] = entry.Value;

        // Save specs get the escaped local name prepended to their own name.
        if (entry.Value is SaveSpec spec)
        {
            spec.name = escaped_name + spec.name;
        }

        if (object_graph_proto is null)
        {
            continue;
        }

        var attributes = object_graph_proto.Nodes[trackable_data.node_id].Attributes;
        attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
        {
            Name = escaped_name,
            CheckpointKey = key,
            FullName = CheckPointUtils.get_full_name(source)
        });
    }
    return keyed_tensors;
}

/// <summary>
/// Collects the tensors of a trackable that still relies on legacy SaveableObjects.
/// The saveables are generated through <c>SaveUtilV1</c> and wrapped in a
/// <c>SaveableCompatibilityConverter</c>, which is returned together with its tensors.
/// </summary>
/// <param name="trackable_data">The trackable to process.</param>
/// <param name="node_ids">Node id of every trackable in the object graph.</param>
/// <param name="call_with_mapped_captures">Forwarded to the saveable generation.</param>
/// <param name="object_graph_proto">Proto passed on to the saveable generation.</param>
/// <returns>The converter wrapping the object to save, and the converter's tensors.</returns>
private static (Trackable, IDictionary<string, object>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
    bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
    // Single-entry maps describing just this trackable.
    var names = new Dictionary<Trackable, string>
    {
        [trackable_data.trackable] = trackable_data.object_name
    };
    var mapping = new Dictionary<Trackable, Trackable>
    {
        [trackable_data.trackable] = trackable_data.object_to_save
    };

    var (factories, _) = SaveUtilV1.get_checkpoint_factories_and_keys(names, mapping);
    var (saveables, _) = SaveUtilV1.generate_saveable_objects(factories, object_graph_proto, node_ids, mapping,
        call_with_mapped_captures, saveables_cache: null);

    var converter = new SaveableCompatibilityConverter(trackable_data.object_to_save, saveables);
    var tensors = converter.serialize_to_tensors();
    return (converter, tensors);
}

/// <summary>
/// Groups the objects handled by registered savers: saver name -> (object name -> object to save).
/// </summary>
/// <param name="registered_trackables">Trackables grouped by the name of their registered saver.</param>
/// <param name="object_graph_proto">Object graph proto; each trackable's node is looked up by its node id.</param>
/// <returns>The nested dictionary of registered savers. Saver names without trackables produce no entry.</returns>
private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
{
    Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
    foreach (var pair in registered_trackables)
    {
        foreach (var td in pair.Value)
        {
            // Create the inner dictionary on first use, then always record the object.
            // (The check used to be inverted: the first object of a saver hit a missing key,
            // and every later one replaced the inner dictionary instead of adding to it.)
            if (!registered_savers.TryGetValue(pair.Key, out var savers_for_name))
            {
                savers_for_name = new Dictionary<string, Trackable>();
                registered_savers[pair.Key] = savers_for_name;
            }
            savers_for_name[td.object_name] = td.object_to_save;

            var object_proto = object_graph_proto.Nodes[td.node_id];
            // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
        }
    }
    return registered_savers;
}

/// <summary>
/// Splits the trackables into tensor trackables, PyState trackables and registered trackables.
/// For now every trackable is treated as a tensor trackable; the other two groups stay empty.
/// </summary>
private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
{
    var with_tensors = new List<TrackableData>();
    // `PyState` is skipped for the lack of API; this list is only a placeholder.
    var with_py_state = new List<TrackableData>();
    var with_registered_saver = new Dictionary<string, IList<TrackableData>>();

    foreach (var item in trackable_data)
    {
        // TODO: deal with registration.
        with_tensors.Add(item);
    }

    return (with_tensors, with_py_state, with_registered_saver);
}
}
}

+ 8
- 10
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -7,6 +7,7 @@ using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Google.Protobuf;

namespace Tensorflow.Checkpoint;

@@ -47,19 +48,16 @@ public static class SaveUtilV1
IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
object? saveables_cache = null)
{

Graph target_context;
if (to_graph is not null)
{
using (to_graph.as_default())
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
to_graph.as_default();
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
// tensorflow python: `with ops.device("/cpu:0")`
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}
// tensorflow python: `with ops.device("/cpu:0")`
var serialized = graph_proto.ToByteString().ToString();
var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING);
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}
else
{


+ 16
- 0
src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// Compatibility helpers for legacy saveable names.
/// </summary>
internal static class SaveableCompat
{
    /// <summary>
    /// Returns the legacy saveable name of <paramref name="cls_or_obj"/>.
    /// Not implemented yet: the result is always null.
    /// </summary>
    // TODO: implement it with Attribute.
    public static string? get_saveable_name(Trackable cls_or_obj) => null;
}
}

+ 0
- 109
src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs View File

@@ -1,109 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;

namespace Tensorflow.Checkpoint;

public class TrackableSaver
{
private ObjectGraphView _graph_view;
private EagerTensor _cached_save_operation;
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
public TrackableSaver(ObjectGraphView graph_view)
{
_graph_view = graph_view;
// TODO: cache when not executing eagerly.
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
}

private void gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
throw new NotImplementedException();
}

private (EagerTensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
throw new NotImplementedException();
}
// TODO: parameter write_done_callback
public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
CheckpointOptions? options = null)
{
if (options is null)
{
options = new CheckpointOptions();
}

Dictionary<Tensor, string> feed_dict = new();
bool use_session = (!new Context().executing_eagerly() && !ops.inside_function());
if (checkpoint_number is not null)
{
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
}

Tensor file_prefix_tensor;
Tensor object_graph_tensor;
if (use_session)
{
if (_object_graph_feed_tensor is null)
{
// In python there is `with ops.device("/cpu:0")`.
_object_graph_feed_tensor = constant_op.constant("", dtypes.variant);
_file_prefix_feed_tensor = constant_op.constant("", dtypes.variant);
}

object_graph_tensor = _object_graph_feed_tensor;
file_prefix_tensor = _file_prefix_feed_tensor;
feed_dict[file_prefix_tensor] = file_prefix;
}
else
{
// In python there is `with ops.device("/cpu:0")`.
file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant);
object_graph_tensor = null;
}

var (save_path, new_feed_additions) =
save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);

if (new_feed_additions is not null)
{
foreach (var pair in new_feed_additions)
{
feed_dict.Add(pair.Key, pair.Value);
}
}
if(!use_session)
{
session = null;
}
else if (session is null)
{
session = new Session(); // In python it uses `get_session`.
}

if (session is not null)
{
var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
return session.run((Tensor)save_path, s);
}
else if (use_session)
{
throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " +
"in graph mode without a default session. Please use " +
"`with tf.Session():` to create a session.");
}
else
{
return save_path;
}
}
}

+ 7
- 1
src/TensorFlowNET.Core/Checkpoint/TrackableView.cs View File

@@ -21,9 +21,14 @@ public class TrackableView
/// <summary>
/// Returns the children of <paramref name="obj"/> for the given save type, copied into a new dictionary.
/// </summary>
/// <param name="obj">The trackable whose children are requested; it is initialized first if necessary.</param>
/// <param name="save_type">The kind of save the children are gathered for.</param>
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
{
    obj._maybe_initialize_trackable();
    // Note: in python the return type of `Trackable._trackable_children` is not fixed.
    // Therefore it uses `convert_to_trackable` to have an extra process.
    // The stale early `return obj._trackable_children(save_type);` is dropped here:
    // it made the copy below unreachable.
    Dictionary<string, Trackable> result = new();
    foreach (var pair in obj._trackable_children(save_type))
    {
        result[pair.Key] = pair.Value;
    }
    return result;
}
public Trackable Root
@@ -50,6 +55,7 @@ public class TrackableView
{
List<Trackable> bfs_sorted = new();
Queue<Trackable> to_visit = new();
to_visit.Enqueue(Root);
Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new();
node_paths[this.Root] = new List<TrackableReference>();
while (!to_visit.empty())


+ 191
- 0
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -0,0 +1,191 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Saves and restores a `Trackable` object and its dependencies.
/// </summary>
public class TrackableSaver
{
private ObjectGraphView _graph_view;
private Tensor _cached_save_operation;
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
private Dictionary<Trackable, Trackable>? _object_map = null;
private object? _cache = null;
/// <summary>
/// Creates a saver for the objects reachable through <paramref name="graph_view"/>.
/// </summary>
// TODO: cache when not executing eagerly.
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
public TrackableSaver(ObjectGraphView graph_view) => _graph_view = graph_view;

/// <summary>
/// Serializes the graph view and adds the object graph itself to the tensors to save.
/// </summary>
/// <param name="object_graph_tensor">
/// Tensor through which the object graph is fed. When null, a constant holding the graph is created instead.
/// </param>
/// <returns>The serialized tensors, the feed additions, the registered savers and the object graph proto.</returns>
private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
    gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
    var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);

    // TODO: cache.

    // NOTE(review): `ToString()` is used as the serialized form of the proto here; the Python
    // implementation uses `SerializeToString()`. Confirm this text form is what the checkpoint
    // reader expects.
    if (object_graph_tensor is null)
    {
        // tensorflow python: `with ops.device("/cpu:0"):`
        object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
    }
    else
    {
        // `serialize_graph_view` hands back null feed additions when no cache is in use,
        // so the dictionary has to be created before the graph can be fed.
        feed_additions ??= new Dictionary<Tensor, string>();
        feed_additions[object_graph_tensor] = graph_proto.ToString();
    }

    // The object graph is stored under the `Trackable.None` entry, which is created on demand
    // (the equivalent of Python's `setdefault`); otherwise the graph would silently be left out
    // of the checkpoint whenever no such entry exists yet.
    if (!serialized_tensors.TryGetValue(Trackable.None, out var unowned_tensors))
    {
        unowned_tensors = new Dictionary<string, object>();
        serialized_tensors[Trackable.None] = unowned_tensors;
    }
    Debug.Assert(!unowned_tensors.ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
    unowned_tensors[Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;

    return (serialized_tensors, feed_additions, registered_savers, graph_proto);
}

/// <summary>
/// Builds the save operation for a tensor-valued file prefix, or reuses the cached one when the
/// object graph is the same instance as on the previous save and neither eager execution nor
/// function building is active.
/// </summary>
/// <returns>The cached save operation and the feed additions gathered during serialization.</returns>
/// <exception cref="NotImplementedException">Thrown when asynchronous checkpointing is requested.</exception>
private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
    var (tensors_by_trackable, feeds, savers_by_name, object_graph) = gather_serialized_tensors(object_graph_tensor);

    if (options.experimental_enable_async_checkpoint)
    {
        throw new NotImplementedException();
    }

    bool rebuild = _last_save_object_graph != object_graph || tf.Context.executing_eagerly() || ops.inside_function();
    if (rebuild)
    {
        var multi_device_saver = new MultiDeviceSaver(tensors_by_trackable, savers_by_name);
        var save_op = multi_device_saver.save(file_prefix, options);

        // tensorflow python: `with ops.device("/cpu:0"):`
        using (ops.control_dependencies(new object[] { save_op }))
        {
            _cached_save_operation = array_ops.identity(file_prefix);
        }
        _last_save_object_graph = object_graph;
    }

    return (_cached_save_operation, feeds);
}

/// <summary>
/// Builds the save operation for a string file prefix, or reuses the cached one when the
/// object graph is the same instance as on the previous save and neither eager execution nor
/// function building is active. The prefix is wrapped in a constant for the returned tensor.
/// </summary>
/// <returns>The cached save operation and the feed additions gathered during serialization.</returns>
/// <exception cref="NotImplementedException">Thrown when asynchronous checkpointing is requested.</exception>
private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
    var (tensors_by_trackable, feeds, savers_by_name, object_graph) = gather_serialized_tensors(object_graph_tensor);

    if (options.experimental_enable_async_checkpoint)
    {
        throw new NotImplementedException();
    }

    bool rebuild = _last_save_object_graph != object_graph || tf.Context.executing_eagerly() || ops.inside_function();
    if (rebuild)
    {
        var multi_device_saver = new MultiDeviceSaver(tensors_by_trackable, savers_by_name);
        var save_op = multi_device_saver.save(file_prefix, options);

        // tensorflow python: `with ops.device("/cpu:0"):`
        using (ops.control_dependencies(new object[] {save_op} ))
        {
            _cached_save_operation = array_ops.identity(tf.constant(file_prefix));
        }
        _last_save_object_graph = object_graph;
    }

    return (_cached_save_operation, feeds);
}

/// <summary>
/// Saves the objects reachable from the graph view to a checkpoint whose path starts with
/// <paramref name="file_prefix"/>.
/// </summary>
/// <param name="file_prefix">Prefix of the checkpoint path.</param>
/// <param name="checkpoint_number">When given, "-{checkpoint_number}" is appended to the prefix.</param>
/// <param name="session">Session used when a session is required; a new one is created when omitted. Ignored otherwise.</param>
/// <param name="options">Checkpoint options; a default instance is used when omitted.</param>
/// <returns>
/// When a session is used, the result of running the save-path tensor in it; otherwise the save-path tensor itself.
/// </returns>
// TODO: parameter write_done_callback
public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
    CheckpointOptions? options = null)
{
    options ??= new CheckpointOptions();

    Dictionary<Tensor, string> feed_dict = new();
    // Ask the shared context (as the rest of this class does); a freshly constructed
    // Context does not reflect the execution mode configured by the application.
    bool use_session = !tf.Context.executing_eagerly() && !ops.inside_function();
    if (checkpoint_number is not null)
    {
        file_prefix = $"{file_prefix}-{checkpoint_number}";
    }

    Tensor file_prefix_tensor;
    Tensor object_graph_tensor;
    if (use_session)
    {
        if (_object_graph_feed_tensor is null)
        {
            // In python there is `with ops.device("/cpu:0")`.
            _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
            _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
        }

        object_graph_tensor = _object_graph_feed_tensor;
        file_prefix_tensor = _file_prefix_feed_tensor;
        feed_dict[file_prefix_tensor] = file_prefix;
    }
    else
    {
        // In python there is `with ops.device("/cpu:0")`.
        file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
        object_graph_tensor = null;
    }

    var (save_path, new_feed_additions) =
        save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);

    if (new_feed_additions is not null)
    {
        foreach (var pair in new_feed_additions)
        {
            // Overwrite instead of Add so a tensor already present in the feed does not throw.
            feed_dict[pair.Key] = pair.Value;
        }
    }

    if (!use_session)
    {
        return save_path;
    }

    // In python it uses `get_session`. Because a session is always created here when none was
    // passed in, the former "no default session" error branch could never be reached.
    session ??= new Session();
    var feeds = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
    return session.run(save_path, feeds);
}
}

+ 36
- 0
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Buffers.Text;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.OptimizerOptions.Types;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// Saves checkpoints directly from multiple devices.
/// Note that this is a low-level utility which stores Tensors in the keys
/// specified by `SaveableObject`s. Higher-level utilities for object-based
/// checkpointing are built on top of it.
/// </summary>
public class MultiDeviceSaver
{
    /// <summary>
    /// Creates the saver. The arguments are not stored yet: the implementation is still a stub.
    /// </summary>
    public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, object>> serialized_tensors,
        IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
    {
    }

    /// <summary>Not implemented yet; always throws <see cref="NotImplementedException"/>.</summary>
    public Operation? save(string file_prefix, CheckpointOptions? options= null)
        => throw new NotImplementedException();

    /// <summary>Not implemented yet; always throws <see cref="NotImplementedException"/>.</summary>
    public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
        => throw new NotImplementedException();
}
}

+ 10
- 0
src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs View File

@@ -205,6 +205,16 @@ namespace Tensorflow {
slotVariables_ = slot;
}

// Creates a TrackableObject from existing slot-variable and children lists.
// The given RepeatedField instances are assigned to the backing fields directly
// (no copy is made), after OnConstruction() has run.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot,
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children
)
{
OnConstruction();
slotVariables_ = slot;
children_ = children;
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]


+ 50
- 1
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -1,4 +1,10 @@
namespace Tensorflow.Train
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{
public abstract class AutoTrackable : Trackable
{
@@ -17,5 +23,48 @@
}
}
}

/// <summary>
/// Returns the children of this object. For save types other than <c>SAVEDMODEL</c> the base
/// implementation is used; for <c>SAVEDMODEL</c> the checkpoint dependencies are combined with
/// the functions exposed through this object's public properties.
/// </summary>
/// <param name="save_type">The kind of save the children are gathered for.</param>
/// <param name="cache">Passed through to the base implementation.</param>
/// <returns>Map from child name to child; a function wins over a dependency with the same name.</returns>
/// <exception cref="ValueError">
/// Thrown when a checkpoint dependency and a function share a name but are different objects.
/// </exception>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
{
    if(save_type != SaveType.SAVEDMODEL)
    {
        return base._trackable_children(save_type, cache);
    }

    Dictionary<string, Trackable> functions = new();
    // TODO: process of logs.
    foreach (var property in this.GetType().GetProperties())
    {
        // Indexers need arguments and write-only properties have no getter;
        // calling GetValue(this, null) on either would throw.
        if (!property.CanRead || property.GetIndexParameters().Length > 0)
        {
            continue;
        }

        object value = property.GetValue(this, null);
        if(value is Function || value is ConcreteFunction)
        {
            functions[property.Name] = (Trackable)value;
        }
    }

    // TODO: process the type `core_types.GenericFunction`.

    Dictionary<string, Trackable> children = new();
    foreach(var pair in CheckpointDependencies)
    {
        var name = pair.Name;
        var child = pair.Refer;
        if(child is ConcreteFunction) // or Generic function
        {
            continue;
        }
        if(functions.ContainsKey(name) && functions[name] != child)
        {
            throw new ValueError($"Can't save object because it has multiple children with the same " +
                $"name. Object: {this}, attribute name: {name}, child 1: " +
                $"{child}, child 2: {functions[name]}");
        }
        children[name] = child;
    }

    // Merge with overwrite (the equivalent of Python's dict.update). A name may legitimately
    // occur in both maps when it refers to the same object, which made
    // Concat(...).ToDictionary(...) throw on the duplicate key.
    foreach (var pair in functions)
    {
        children[pair.Key] = pair.Value;
    }
    return children;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
public string slice_spec => _slice_spec;

private string _name;
public string name => _name;
public string name { get => _name; set => _name = value; }

private TF_DataType _dtype;
public TF_DataType dtype => _dtype;


+ 26
- 26
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -134,35 +134,33 @@ public static partial class SavedModelUtils
Dictionary<Trackable, Trackable> object_map;
Dictionary<Tensor, Tensor> tensor_map;
AssetInfo asset_info;
using (var g = exported_graph.as_default())
exported_graph.as_default();
(object_map, tensor_map, asset_info) = saveable_view.map_resources();
// TODO: deal with signatures.
if (save_custom_gradients)
{
(object_map, tensor_map, asset_info) = saveable_view.map_resources();
// TODO: deal with signatures.
if (save_custom_gradients)
{
// TODO: trace gradient functions.
}
// TODO: trace gradient functions.
}

foreach (var resource_initializer_function in resource_initializers)
{
// List<Trackable> asset_dependencies = new();
// TODO: deal with initializers
}
// using(ops.control_dependencies(...))
var init_op = control_flow_ops.no_op();
if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY))
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name);
}
else
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef();
}
// Lack `CopyFrom` API
// meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]
foreach (var resource_initializer_function in resource_initializers)
{
// List<Trackable> asset_dependencies = new();
// TODO: deal with initializers
}

// using(ops.control_dependencies(...))
var init_op = control_flow_ops.no_op();
if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY))
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name);
}
else
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef();
}
// Lack `CopyFrom` API
// meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]

foreach (var obj in object_map.Values)
{
obj._maybe_initialize_trackable();
@@ -180,11 +178,13 @@ public static partial class SavedModelUtils
verify_ops(graph_def, namespace_whitelist);

meta_graph_def.GraphDef = new GraphDef(graph_def);
meta_graph_def.MetaInfoDef = new();
meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING);
meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION;
// TODO: add git version.
meta_graph_def.MetaInfoDef.TensorflowGitVersion = "";
meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
meta_graph_def.MetaInfoDef.StrippedOpList = new();
meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef));
meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs);


+ 50
- 0
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -138,5 +138,55 @@ namespace Tensorflow
// TODO: implement it.
return false;
}

/// <summary>
/// Converts <paramref name="x"/> to a string by delegating to <c>tf.compat.as_str</c>.
/// </summary>
internal static string convert_to_string(string x) => tf.compat.as_str(x);
}

/// <summary>
/// Trackable wrapper that exposes a list of legacy saveable objects through
/// <see cref="serialize_to_tensors"/>.
/// </summary>
public class SaveableCompatibilityConverter: Trackable
{
    private Trackable _obj;
    private IList<MySaveableObject> _saveables;

    /// <summary>
    /// Wraps <paramref name="obj"/> together with the saveables generated for it.
    /// </summary>
    public SaveableCompatibilityConverter(Trackable obj, IList<MySaveableObject> saveables)
    {
        _obj = obj;
        _saveables = saveables;
    }

    /// <summary>The wrapped trackable.</summary>
    public Trackable Obj
    {
        get { return _obj; }
    }

    /// <summary>The saveable objects of the wrapped trackable.</summary>
    public IList<MySaveableObject> mySaveables
    {
        get { return _saveables; }
    }

    /// <summary>
    /// Returns the tensors of the wrapped saveables, keyed by spec name.
    /// </summary>
    public override IDictionary<string, object> serialize_to_tensors()
        => saveable_objects_to_tensor_dict(_saveables);

    /// <summary>
    /// Converts a list of SaveableObjects to a tensor dictionary (spec name -> spec tensor).
    /// </summary>
    /// <param name="saveables">The saveable objects whose specs are converted.</param>
    /// <exception cref="NotImplementedException">Thrown for a spec with a non-empty slice spec.</exception>
    public static Dictionary<string, object> saveable_objects_to_tensor_dict(IList<MySaveableObject> saveables)
    {
        var result = new Dictionary<string, object>();
        foreach (var saveable in saveables)
        {
            foreach (var spec in saveable.specs)
            {
                var key = saveable_object_util.convert_to_string(spec.name);
                var slice = saveable_object_util.convert_to_string(spec.slice_spec);
                if (string.IsNullOrEmpty(slice))
                {
                    result[key] = spec.tensor;
                    continue;
                }

                // Sliced specs are not supported yet.
                throw new NotImplementedException();
            }
        }
        return result;
    }
}
}

+ 37
- 5
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -34,18 +34,35 @@ namespace Tensorflow.Train
public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON";
}
protected int _self_update_uid;
protected IDictionary<string, Trackable> _unconditional_dependency_names =
new Dictionary<string, Trackable>();
protected IDictionary<string, Trackable> _unconditional_dependency_names;

protected IList<TrackableReference> _unconditional_checkpoint_dependencies = new List<TrackableReference>();
protected IList<TrackableReference> _unconditional_checkpoint_dependencies;

protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
new Dictionary<string, ResourceVariable>();

private static Trackable _none = new Function();
/// <summary>
/// Placeholder used where the Python implementation uses `None` as a dictionary key,
/// because a C# `Dictionary` does not accept a null key.
/// Any object that inherits `Trackable` could serve as the placeholder.
/// This property is intended for internal use only.
/// </summary>
public static Trackable None => _none;
/// <summary>Identifier of this kind of object; "_generic_user_object" unless overridden.</summary>
public virtual string ObjectIdentifier => "_generic_user_object";

/// <summary>Gets or sets the update uid of this trackable.</summary>
public int UpdateUid
{
    get { return _self_update_uid; }
    set { _self_update_uid = value; }
}

/// <summary>The unconditional checkpoint dependencies of this trackable.</summary>
public IList<TrackableReference> UnconditionalCheckpointDependencies => _unconditional_checkpoint_dependencies;

/// <summary>The unconditional dependencies of this trackable, by name.</summary>
public IDictionary<string, Trackable> UnconditionalDependencyNames => _unconditional_dependency_names;

/// <summary>The checkpoint dependencies; the same list as <see cref="UnconditionalCheckpointDependencies"/>.</summary>
public IList<TrackableReference> CheckpointDependencies => UnconditionalCheckpointDependencies;

/// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`.
/// </summary>
@@ -99,8 +116,9 @@ namespace Tensorflow.Train
/// </summary>
public void _maybe_initialize_trackable()
{
    // Already initialized: keep the tracked dependencies. Without this guard every call
    // (this method is invoked each time an object's children are listed) would replace
    // the collections below and silently drop everything tracked so far.
    if (_unconditional_checkpoint_dependencies is not null)
    {
        return;
    }

    _self_update_uid = -1;
    _unconditional_checkpoint_dependencies = new List<TrackableReference>();
    _unconditional_dependency_names = new Dictionary<string, Trackable>();
}

// TODO: cache
@@ -153,6 +171,20 @@ namespace Tensorflow.Train
{
return _self_saveable_object_factories;
}

/// <summary>
/// Gathers the tensors to save to the checkpoint.
/// Override `serialize_to_tensors` and `restore_from_tensors` only when defining a custom
/// resource or variable with custom ops. Otherwise, store the state of the trackable in
/// `tf.Variable` objects and add them to the Trackable object hierarchy using `setattr`
/// (for subclasses of `AutoTrackable`) or by overriding the `_trackable_children` method.
/// </summary>
/// <returns>The base implementation never returns; it always throws.</returns>
/// <exception cref="NotImplementedException">Always thrown unless the method is overridden.</exception>
public virtual IDictionary<string, object> serialize_to_tensors()
    => throw new NotImplementedException();
}

public record class TrackableReference(string Name, Trackable Refer);


Loading…
Cancel
Save