
Merge branch 'master' into add_pb_model_save

pull/976/head
AsakusaRinne 2 years ago
commit c9c339d309
100 changed files with 5151 additions and 109 deletions
1. +22 -0 src/TensorFlowNET.Core/APIs/tf.compat.cs
2. +152 -0 src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
3. +5 -0 src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
4. +64 -0 src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
5. +255 -0 src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
6. +222 -0 src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
7. +16 -0 src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs
8. +82 -0 src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
9. +195 -0 src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
10. +559 -0 src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
11. +68 -0 src/TensorFlowNET.Core/DisposableObject.cs
12. +31 -0 src/TensorFlowNET.Core/Eager/execute.cs
13. +14 -0 src/TensorFlowNET.Core/Exceptions/AssertionError.cs
14. +85 -1 src/TensorFlowNET.Core/Framework/meta_graph.cs
15. +2 -1 src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
16. +9 -2 src/TensorFlowNET.Core/Functions/Function.cs
17. +32 -14 src/TensorFlowNET.Core/Graphs/AutoGraph.cs
18. +13 -4 src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
19. +19 -0 src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
20. +41 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
21. +15 -2 src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
22. +2 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
23. +2 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
24. +17 -14 src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
25. +4 -2 src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
26. +4 -2 src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
27. +5 -2 src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs
28. +50 -0 src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
29. +48 -0 src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
30. +73 -0 src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
31. +67 -0 src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
32. +30 -1 src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
33. +4 -2 src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
34. +15 -0 src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs
35. +7 -2 src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
36. +7 -2 src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
37. +5 -2 src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
38. +35 -0 src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs
39. +21 -0 src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
40. +43 -1 src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
41. +10 -1 src/TensorFlowNET.Core/NumPy/Axis.cs
42. +3 -0 src/TensorFlowNET.Core/Numpy/Shape.cs
43. +10 -0 src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
44. +9 -1 src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
45. +7 -0 src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
46. +7 -0 src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
47. +6 -1 src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
48. +12 -0 src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
49. +12 -0 src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
50. +11 -0 src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
51. +13 -0 src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
52. +5 -0 src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
53. +7 -1 src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
54. +55 -4 src/TensorFlowNET.Core/Operations/gen_ops.cs
55. +32 -0 src/TensorFlowNET.Core/Operations/io_ops.cs
56. +60 -0 src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
57. +9 -1 src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs
58. +16 -0 src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs
59. +1 -0 src/TensorFlowNET.Core/Tensorflow.Binding.csproj
60. +18 -0 src/TensorFlowNET.Core/Tensors/dtypes.cs
61. +67 -2 src/TensorFlowNET.Core/Training/AutoTrackable.cs
62. +12 -0 src/TensorFlowNET.Core/Training/IWithTrackable.cs
63. +9 -0 src/TensorFlowNET.Core/Training/LayerUtils.cs
64. +6 -1 src/TensorFlowNET.Core/Training/Optimizer.cs
65. +28 -0 src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs
66. +1 -1 src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs
67. +36 -2 src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
68. +11 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs
69. +133 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs
70. +33 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs
71. +17 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
72. +9 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs
73. +299 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
74. +10 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs
75. +22 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs
76. +269 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
77. +53 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs
78. +107 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs
79. +57 -0 src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs
80. +254 -2 src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
81. +186 -6 src/TensorFlowNET.Core/Training/Trackable.cs
82. +172 -0 src/TensorFlowNET.Core/Training/TrackableUtils.cs
83. +370 -0 src/TensorFlowNET.Core/Training/data_structures.cs
84. +71 -3 src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
85. +1 -0 src/TensorFlowNET.Core/Variables/IVariableV1.cs
86. +3 -1 src/TensorFlowNET.Core/Variables/RefVariable.cs
87. +3 -0 src/TensorFlowNET.Core/Variables/ResourceVariable.cs
88. +70 -0 src/TensorFlowNET.Core/Variables/UninitializedVariable.cs
89. +18 -0 src/TensorFlowNET.Core/ops.cs
90. +17 -14 src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
91. +50 -0 src/TensorFlowNET.Keras/Engine/Functional.cs
92. +32 -0 src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
93. +32 -4 src/TensorFlowNET.Keras/Engine/Layer.cs
94. +5 -0 src/TensorFlowNET.Keras/Engine/Model.Fit.cs
95. +17 -2 src/TensorFlowNET.Keras/Engine/Model.Save.cs
96. +19 -0 src/TensorFlowNET.Keras/Engine/Model.cs
97. +1 -0 src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
98. +1 -0 src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
99. +5 -4 src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
100. +2 -1 src/TensorFlowNET.Keras/Layers/Attention/Attention.cs

+22 -0 src/TensorFlowNET.Core/APIs/tf.compat.cs

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Text;

namespace Tensorflow
{
public partial class tensorflow
@@ -23,6 +25,26 @@ namespace Tensorflow
public class CompatApi
{
public CompatV1Api v1 { get; } = new CompatV1Api();

internal string as_text(string bytes_or_text, Encoding? encoding = null)
{
if(encoding is null) encoding = Encoding.UTF8;
return bytes_or_text;
}
internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
{
if(encoding is null) encoding = Encoding.UTF8;
return encoding.GetString(bytes_or_text);
}
internal string as_str(string bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}
internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}
}

public bool executing_eagerly()
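
The new `as_text`/`as_str` helpers mirror `tf.compat.as_text`/`as_str` from upstream TensorFlow: byte arrays are decoded with the given encoding (UTF-8 by default), strings pass through unchanged, and `as_str` simply forwards to `as_text`. A minimal sketch of the intended behaviour, assuming the `CompatApi` instance is exposed as `tf.compat` elsewhere in this file and the caller sits inside the assembly (the helpers are `internal`):

using System.Text;
using static Tensorflow.Binding;

var bytes = Encoding.UTF8.GetBytes("saved_model.pb");
string fromBytes = tf.compat.as_text(bytes);              // decoded with UTF-8 -> "saved_model.pb"
string fromString = tf.compat.as_text("saved_model.pb");  // strings are returned as-is
string alias = tf.compat.as_str(bytes);                   // forwards to as_text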


+152 -0 src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs

@@ -0,0 +1,152 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint;

public static class CheckPointUtils
{
private static string _ESCAPE_CHAR = ".";
public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new();
foreach (var pair in node_paths)
{
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
}

Dictionary<Trackable, int> node_ids = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
node_ids[trackable_objects[i]] = i;
}

var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
}

public static
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
{
var non_slot_objects = trackable_objects.ToList();
Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
slot_variables = new();
foreach (var trackable in non_slot_objects)
{
if (trackable is not Optimizer)
{
continue;
}

var optim = (Optimizer)trackable;
var slot_names = optim.get_slot_names();
foreach (var slot_name in slot_names)
{
for (int original_variable_node_id = 0;
original_variable_node_id < non_slot_objects.Count;
original_variable_node_id++)
{
var original_variable = non_slot_objects[original_variable_node_id];
IVariableV1 slot_variable;
if (original_variable is IVariableV1 original_variable_v1)
{
slot_variable = optim.get_slot(original_variable_v1, slot_name);
}
else
{
slot_variable = null;
}
if(slot_variable is null) continue;

// There are some problems with the inheritance relationship between `Variable` and `Trackable`.
throw new NotImplementedException();
}
}
}

return slot_variables;
}

public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
{
if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
{
return trackable;
}
else
{
return possible_res;
}
}

public static string get_full_name(Trackable variable)
{
// TODO: This state is not correct; the whole framework needs to be updated in the future.
if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
{
return "";
}
// skip the check of attribute `_save_slice_info`.

// TODO: Need to be revised!!!
Debug.Assert(variable is BaseResourceVariable);
return ((BaseResourceVariable)variable).Name;
}

public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
{
HashSet<int> checkpointed_trackables = new();
Dictionary<int, HashSet<int>> parents = new();
for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
{
var object_proto = object_graph_proto.Nodes[i];
// skip the process of registered saver.
if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
{
checkpointed_trackables.Add(i);
}

foreach (var child_proto in object_proto.Children)
{
var child = child_proto.NodeId;
if (!parents.ContainsKey(child))
{
parents[child] = new HashSet<int>();
}

parents[child].Add(i);
}
}

Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable());
while (to_visit.Count > 0)
{
var trackable = to_visit.Dequeue();
if (!parents.ContainsKey(trackable)) continue;
var current_parents = parents[trackable];
foreach (var parent in current_parents)
{
checkpointed_trackables.Add(parent);
if (parents.ContainsKey(parent))
{
to_visit.Enqueue(parent);
}
}
parents.Remove(trackable);
}
// TODO: Complete it after supporting checkpoint.
// for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
// {
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}
}
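
As a rough usage sketch (not part of this diff), `objects_ids_and_slot_variables_and_paths` walks an `ObjectGraphView` and returns everything the checkpoint writer needs to name nodes; `root` below stands for any `Trackable`:

var view = new ObjectGraphView(root);
var (trackables, nodePaths, nodeIds, slotVariables, objectNames) =
    CheckPointUtils.objects_ids_and_slot_variables_and_paths(view);
// trackables is in BFS order, so nodeIds[trackables[i]] == i;
// objectNames[t] is the escaped object path later embedded in checkpoint keys.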

+5 -0 src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs

@@ -0,0 +1,5 @@
namespace Tensorflow.Checkpoint;

public record class CheckpointOptions(
string? experimental_io_device = null,
bool experimental_enable_async_checkpoint = false);
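
Since this is a positional record with defaults, constructing the options is straightforward; for example:

var defaults = new CheckpointOptions();  // no device override, synchronous checkpointing
var options = new CheckpointOptions(experimental_io_device: "/job:localhost");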

+64 -0 src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs

@@ -0,0 +1,64 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Serilog.Debugging;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;

public class ObjectGraphView: TrackableView, ICloneable
{
protected IEnumerable<TrackableReference>? _attached_dependencies;
// TODO: attached_dependencies
public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root)
{
_attached_dependencies = attached_dependencies;
}

public object Clone()
{
// TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__
return new ObjectGraphView(Root, _attached_dependencies);
}

public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
// Check the reference, not value.
if (obj == Root && _attached_dependencies is not null)
{
res.AddRange(_attached_dependencies);
}

return res;
}
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
}
public IEnumerable<TrackableReference>? AttachedDependencies
{
get => _attached_dependencies;
}

public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{
return base._descendants_with_paths();
}

// TODO: complete the implementation
public void serialize_object_graph(object? saveables_cache = null)
{
throw new NotImplementedException();
}
// TODO: complete the implementation
public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
{
throw new NotImplementedException();
}
}
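
A short sketch of how the two views differ (again, `root` stands for any `Trackable`): `list_children` preserves order and appends the attached dependencies when called on the root, while `children` exposes the same information as a name-keyed dictionary.

var view = new ObjectGraphView(root);
foreach (var reference in view.list_children(root))
    Console.WriteLine($"{reference.Name} -> {reference.Refer.GetType().Name}");
var (bfsNodes, nodePaths) = view.breadth_first_traversal();  // root first, then its descendants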

+255 -0 src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs

@@ -0,0 +1,255 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint
{
internal record class TrackableData(
// A trackable in the root Trackable object graph.
Trackable trackable,
// The index at which the Trackable appears in TrackableObjectGraph.nodes.
int node_id,
// The BFS-generated path from the root object, used to generate readable checkpoint keys.
string object_name,
// A list of ObjectReference for each child connected to this Trackable.
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
// A list of SlotVariableReference to save to the object (only valid for Optimizer objects).
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
// The object to save to checkpoint. Usually this is the same as `trackable`,
// but can differ when the caller wants to specify a different object to
// save. For example, when saving checkpoints asynchronously, variables are
// copied to the CPU. `object_to_save` is set as the copied variable.
Trackable object_to_save
);
public static class SaveUtil
{
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

var object_graph_proto = fill_object_graph_proto(trackable_data);

var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

Dictionary<Tensor, object> feed_additions;
if(cache is null)
{
feed_additions = null;
serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
}
else
{
feed_additions = null;
// TODO: deal with cache.
throw new NotImplementedException();
}

CheckPointUtils.add_checkpoint_values_check(object_graph_proto);

return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
}

private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new();
foreach(var pair in node_paths)
{
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
}
Dictionary<Trackable, int> node_ids = new();
for(int i = 0; i < trackable_objects.Count; i++)
{
node_ids[trackable_objects[i]] = i;
}
var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
List<TrackableData> trackable_data = new();
foreach(var trackable in trackable_objects)
{
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto = new();
foreach(var child in graph_view.list_children(trackable))
{
children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
{
NodeId = node_ids[child.Refer],
LocalName = child.Name
});
}
slot_variables.TryGetValue(trackable, out var slot_variable);
trackable_data.Add(new TrackableData(
trackable: trackable,
node_id: node_ids[trackable],
object_name: object_names[trackable],
children_proto: children_proto,
slot_variable_proto: slot_variable??new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>(),
object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map)
));
}
return (trackable_data, node_ids);
}

private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
{
TrackableObjectGraph object_graph_proto = new();
for(int i = 0; i < trackable_data.Count; i++)
{
var td = trackable_data[i];
Debug.Assert(td.node_id == i);
object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
}
return object_graph_proto;
}

/// <summary>
/// Creates dictionary of tensors to checkpoint, and updates the proto.
/// </summary>
/// <param name="tensor_trackables"></param>
/// <param name="node_ids"></param>
/// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param>
/// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
foreach(var td in tensor_trackables)
{
// TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
Trackable trackable = null;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
}
else
{
tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto);
trackable = td.object_to_save;
}
if(trackable is not null)
{
serialized_tensors[trackable] = tensor_dict;
}
else
{
serialized_tensors[Trackable.None] = tensor_dict;
}
}
return serialized_tensors;
}

private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
var trackable = trackable_data.object_to_save;

// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
if (call_with_mapped_captures)
{
throw new NotImplementedException();
}
else
{
ret_tensor_dict = trackable.serialize_to_tensors();
}

// TODO: deal with the type `SaveSpec` (currently it never occurs here).
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict)
{
var local_name = TrackableUtils.escape_local_name(pair.Key);
var maybe_tensor = pair.Value;
var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name);

tensor_dict[checkpoint_key] = maybe_tensor;

if(maybe_tensor.GetValueA() is SaveSpec)
{
throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
}

if(object_graph_proto is not null)
{
object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
{
Name = local_name,
CheckpointKey = checkpoint_key,
FullName = CheckPointUtils.get_full_name(trackable)
});
}
}
return tensor_dict;
}

/// <summary>
/// Gets tensors to serialize from a Trackable with legacy SaveableObjects.
/// </summary>
/// <param name="trackable_data"></param>
/// <param name="node_ids"></param>
/// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param>
/// <returns></returns>
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, string> object_names = new();
object_names[trackable_data.trackable] = trackable_data.object_name;
Dictionary<Trackable, Trackable> object_map = new();
object_map[trackable_data.trackable] = trackable_data.object_to_save;

var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map);
var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map,
call_with_mapped_captures, saveables_cache: null);
var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects);
return (trackable, trackable.serialize_to_tensors());
}

private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
{
Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
foreach(var pair in registered_trackables)
{
foreach(var td in pair.Value)
{
if (!registered_savers.ContainsKey(pair.Key))
{
registered_savers[pair.Key] = new Dictionary<string, Trackable>();
}
registered_savers[pair.Key][td.object_name] = td.object_to_save;

var object_proto = object_graph_proto.Nodes[td.node_id];
// TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
}
}
return registered_savers;
}

private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
{
List<TrackableData> tensor_trackables = new();
List<TrackableData> py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a placeholder.
Dictionary<string, IList<TrackableData>> registered_trackables = new();

foreach(var td in trackable_data)
{
// TODO: deal with registration.
tensor_trackables.Add(td);
}
return (tensor_trackables, py_state_trackables, registered_trackables);
}
}
}
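
`serialize_graph_view` is the entry point the object-based checkpoint writer uses. A hedged sketch of what a caller gets back, with `root` again standing for any `Trackable`:

var (serializedTensors, feedAdditions, registeredSavers, objectGraphProto) =
    SaveUtil.serialize_graph_view(new ObjectGraphView(root));
// serializedTensors: Trackable -> (checkpoint key -> Tensor, or checkpoint key -> slice_spec -> Tensor)
// objectGraphProto:  the TrackableObjectGraph proto stored alongside the tensor values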

+222 -0 src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs

@@ -0,0 +1,222 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Google.Protobuf;

namespace Tensorflow.Checkpoint;

public static class SaveUtilV1
{
public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
IDictionary<Trackable, Trackable>? object_map = null)
{
// According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md,
// as of now only internal registrations are allowed, so we won't return a saver from this function.
// The implementation of this function should be updated if tensorflow updates it.
Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new();
foreach (var pair in object_names)
{
var trackable = pair.Key;
var object_name = pair.Value;
var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
// skip the registration process.

List<CheckpointFactoryData> current_list = new();
foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save))
{
// treat name as key_suffix.
var name = name_and_factory.Key;
var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name);
current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key));
}

checkpoint_factory_map[trackable] = current_list;
}

return (checkpoint_factory_map, null);
}

public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
object? saveables_cache = null)
{
if (to_graph is not null)
{
var g = to_graph.as_default();
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
g.Exit();
return (named_saveable_objects, registered_savers);
}
else
{
using (new ops.NullContextManager())
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}
}
}

public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new();
foreach (var pair in node_paths)
{
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
}

Dictionary<Trackable, int> node_ids = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
node_ids[trackable_objects[i]] = i;
}

var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables);
var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph(
trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures,
saveables_cache);
CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers);
}

private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
IDictionary<Trackable, int> node_ids,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
slot_variables)
{
TrackableObjectGraph object_graph_proto = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
var trackable = trackable_objects[i];
Debug.Assert(node_ids[trackable] == i);
TrackableObjectGraph.Types.TrackableObject object_proto;
if (slot_variables.TryGetValue(trackable, out var slots))
{
object_proto = new TrackableObjectGraph.Types.TrackableObject(slots);
}
else
{
object_proto = new TrackableObjectGraph.Types.TrackableObject();
}
object_graph_proto.Nodes.Add(object_proto);
foreach (var child in graph_view.list_children(trackable))
{
object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
{ NodeId = node_ids[child.Refer], LocalName = child.Name });
}
}

return object_graph_proto;
}

private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
bool call_with_mapped_captures, object? saveables_cache = null)
{
int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count);
for (int i = 0; i < cnt; i++)
{
Debug.Assert(node_ids[trackable_objects[i]] == i);
}

var (checkpoint_factory_map, unmapped_registered_savers) =
get_checkpoint_factories_and_keys(object_names, object_map);
// skip the process of registered savers

var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map,
object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache);
return (named_saveable_objects, feed_additions, null);
}

public static (List<MySaveableObject>, object?) generate_saveable_objects(
IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
List<MySaveableObject> named_saveable_objects = new();
foreach (var pair in checkpoint_factory_map)
{
var trackable = pair.Key;
var factory_data_list = pair.Value;
bool fill_object_proto = object_graph_proto is not null && node_ids is not null;
TrackableObjectGraph.Types.TrackableObject object_proto = null!;
if (fill_object_proto)
{
object_proto = object_graph_proto.Nodes[node_ids[trackable]];
}

var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
// skip cache

foreach (var factory_data in factory_data_list)
{
var name = factory_data.name;
var key = factory_data.checkpoint_key;
var maybe_saveable = factory_data.factory;

// TODO: tensorflow python has a process with a callable `saveable_factory`.
List<MySaveableObject> saveables = new();
if (maybe_saveable.DataType == typeof(MySaveableObject))
{
saveables.Add(maybe_saveable.GetValueB());
}
else
{
saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key));
}

foreach (var saveable in saveables)
{
if (!saveable.name.Contains(key))
{
throw new AssertionError($"The object {trackable} produced a SaveableObject with name " +
$"'{saveable.name}' for attribute '{name}'. Expected a name" +
$" containing '{key}'.");
}
}
// skip the process of PythonState
named_saveable_objects.AddRange(saveables);
if(!fill_object_proto) continue;

// skip the process of `TrackableSaveable` because of lack of APIs.

object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
{ Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
}
}

return (named_saveable_objects, null);
}
}

public record class CheckpointFactoryData
(
Maybe<BaseResourceVariable, MySaveableObject> factory,
string name,
string checkpoint_key
);
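
A sketch of how the V1 path derives checkpoint keys (assuming `root` is a `Trackable`, and that `TrackableUtils.checkpoint_key` follows the upstream TensorFlow convention of keys like "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"):

var view = new ObjectGraphView(root);
var (nodes, nodePaths) = view.breadth_first_traversal();
var objectNames = nodes.ToDictionary(t => t, t => TrackableUtils.object_path_to_string(nodePaths[t]));
var (factoryMap, _) = SaveUtilV1.get_checkpoint_factories_and_keys(objectNames);
foreach (var factoryData in factoryMap[root])
    Console.WriteLine($"{factoryData.name} -> {factoryData.checkpoint_key}");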

+16 -0 src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint
{
internal static class SaveableCompat
{
public static string? get_saveable_name(Trackable cls_or_obj)
{
// TODO: implement it with Attribute.
return null;
}
}
}

+82 -0 src/TensorFlowNET.Core/Checkpoint/TrackableView.cs

@@ -0,0 +1,82 @@
using System;
using Tensorflow.Train;
using System.Collections.Generic;
using System.IO;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow.Checkpoint;

public class TrackableView
{
protected WeakReference<Trackable> _root_ref;
public TrackableView(Trackable obj)
{
_root_ref = new WeakReference<Trackable>(obj);
}

public TrackableView(WeakReference<Trackable> obj)
{
_root_ref = obj;
}
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
obj._maybe_initialize_trackable();
Dictionary<string, Trackable> children = new();
// Note: in python the return type of `Trackable._trackable_children` is not fixed.
// Therefore python uses `convert_to_trackable` as an extra conversion step.
foreach (var pair in obj._trackable_children(save_type, cache))
{
children[pair.Key] = pair.Value;
}
return children;
}
public Trackable Root
{
get
{
if (_root_ref.TryGetTarget(out Trackable res))
{
return res;
}
else
{
throw new InvalidDataException(
"Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor.");
}
}
}
/// <summary>
/// Returns a list of all nodes and their paths from self.root using a breadth-first traversal.
/// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths
/// </summary>
protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
{
List<Trackable> bfs_sorted = new();
Queue<Trackable> to_visit = new();
to_visit.Enqueue(Root);
Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new();
node_paths[this.Root] = new List<TrackableReference>();
while (!to_visit.empty())
{
var current_trackable = to_visit.Dequeue();
bfs_sorted.Add(current_trackable);
var children_dict = this.children(current_trackable);
foreach (var name in children_dict.Keys)
{
var dependency = children_dict[name];
if (!node_paths.ContainsKey(dependency))
{
var list = new List<TrackableReference>(node_paths[current_trackable]);
list.Add(new TrackableReference(name, dependency));
node_paths[dependency] = list;
to_visit.Enqueue(dependency);
}
}
}

return (bfs_sorted, node_paths);
}
}

+195 -0 src/TensorFlowNET.Core/Checkpoint/checkpoint.cs

@@ -0,0 +1,195 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Saves and restores a `Trackable` object and its dependencies.
/// </summary>
public class TrackableSaver
{
private ObjectGraphView _graph_view;
private Tensor _cached_save_operation;
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
private Dictionary<Trackable, Trackable>? _object_map = null;
private object? _cache = null;
public TrackableSaver(ObjectGraphView graph_view)
{
_graph_view = graph_view;
// TODO: cache when not executing eagerly.
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
}

private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);

// TODO: cache.

if(object_graph_tensor is null)
{
tf.device("/cpu:0");
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
}
else
{
feed_additions[object_graph_tensor] = graph_proto.ToByteArray();
}
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
if (!serialized_tensors.ContainsKey(Trackable.None))
{
serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
}
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
return (serialized_tensors, feed_additions, registered_savers, graph_proto);
}

private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
{
if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
{
var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
var save_op = saver.save(file_prefix, options);

// tensorflow python: `with ops.device("/cpu:0"):`
using (ops.control_dependencies(new object[] { save_op }))
{
_cached_save_operation = array_ops.identity(file_prefix);
}
_last_save_object_graph = graph_proto;
}
return (_cached_save_operation, feed_additions);
};

if (options.experimental_enable_async_checkpoint)
{
throw new NotImplementedException();
}

return run_save();
}

private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
{
if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
{
var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
var save_op = saver.save(file_prefix, options);

// tensorflow python: `with ops.device("/cpu:0"):`
using (ops.control_dependencies(new object[] {save_op} ))
{
_cached_save_operation = array_ops.identity(tf.constant(file_prefix));
}
_last_save_object_graph = graph_proto;
}
return (_cached_save_operation, feed_additions);
};

if (options.experimental_enable_async_checkpoint)
{
throw new NotImplementedException();
}

return run_save();
}

// TODO: parameter write_done_callback
public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
CheckpointOptions? options = null)
{
if (options is null)
{
options = new CheckpointOptions();
}

Dictionary<Tensor, object> feed_dict = new();
bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
if (checkpoint_number is not null)
{
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
}

Tensor file_prefix_tensor;
Tensor object_graph_tensor;
string file_prefix_to_save;
if (use_session)
{
if (_object_graph_feed_tensor is null)
{
// In python there is `with ops.device("/cpu:0")`.
_object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
_file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
}

object_graph_tensor = _object_graph_feed_tensor;
file_prefix_tensor = _file_prefix_feed_tensor;
feed_dict[file_prefix_tensor] = file_prefix;
file_prefix_to_save = "";
}
else
{
// In python there is `with ops.device("/cpu:0")`.
file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
object_graph_tensor = null;
file_prefix_to_save = file_prefix;
}

var (save_path, new_feed_additions) =
save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options);

if (new_feed_additions is not null)
{
foreach (var pair in new_feed_additions)
{
feed_dict.Add(pair.Key, pair.Value);
}
}
if(!use_session)
{
session = null;
}
else if (session is null)
{
session = new Session(); // In python it uses `get_session`.
}

if (session is not null)
{
var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
return session.run((Tensor)save_path, s);
}
else if (use_session)
{
throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " +
"in graph mode without a default session. Please use " +
"`with tf.Session():` to create a session.");
}
else
{
return save_path;
}
}
}
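
Putting the pieces together, a checkpoint save roughly looks like the sketch below; `model` is a stand-in for any `Trackable`, and the exact public entry point may change once a high-level `tf.train.Checkpoint` analogue is added.

var saver = new TrackableSaver(new ObjectGraphView(model));
Tensor savePath = saver.save("ckpt/model", checkpoint_number: 1, options: new CheckpointOptions());
// In eager mode `savePath` holds the written prefix, e.g. "ckpt/model-1".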

+559 -0 src/TensorFlowNET.Core/Checkpoint/functional_saver.cs

@@ -0,0 +1,559 @@
using System;
using System.Buffers.Text;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.OptimizerOptions.Types;
using static Tensorflow.Binding;
using System.Text.RegularExpressions;
using System.Linq;
using Tensorflow.Operations;
using Tensorflow.Training;
using Tensorflow.Graphs;
using System.Xml.Linq;
using System.Diagnostics;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// `FunctionHolder` is a set of wrapper types that help dynamically invoke some dotnet functions.
/// Note that this API does not guarantee performance. Besides, it is not supposed to be exposed to users.
/// </summary>
public interface IFunctionHolder
{
int ArgCount { get; }
object DynamicInvoke(params object[] args);
}
internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder
{
public int ArgCount => 0;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
public TR Invoke()
{
return Func.Invoke();
}
}
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder
{
public int ArgCount => 1;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder
{
public int ArgCount => 2;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder
{
public int ArgCount => 3;
public object DynamicInvoke(params object[] args)
{
return Func.DynamicInvoke(args);
}
}
public class Maybe<TA, TB>
{
private TA? _valueA = default(TA);
private TB? _valueB = default(TB);
private Type _type;
private bool _assigned = false;
public Maybe(TA value)
{
_valueA = value;
_type= typeof(TA);
_assigned = true;
}
public Maybe(TB value)
{
_valueB = value;
_type = typeof(TB);
_assigned = true;
}

public Type DataType => _type;

public TA GetValueA()
{
if(!_assigned || DataType != typeof(TA))
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
return _valueA;
}
public TB GetValueB()
{
if (!_assigned || DataType != typeof(TB))
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
return _valueB;
}
public object GetValue()
{
if (!_assigned)
{
throw new TypeError("Cannot get the data because of wrong specified type.");
}
if(DataType == typeof(TA) && _valueA is not null)
{
return _valueA;
}
else if(DataType == typeof(TB) && _valueB is not null)
{
return _valueB;
}
else if(DataType == typeof(TA))
{
return _valueA;
}
else
{
return _valueB;
}
}

public static implicit operator Maybe<TA, TB>(TA a)
{
return new Maybe<TA, TB>(a);
}
public static implicit operator Maybe<TA, TB>(TB b)
{
return new Maybe<TA, TB>(b);
}
}
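
// Illustrative sketch (not part of this commit): `Maybe<TA, TB>` acts as a small either-type.
// The implicit conversions let callers pass either alternative directly, and GetValueA/GetValueB
// throw a TypeError when the wrong side is requested, e.g.:
//
//   Maybe<Tensor, SaveSpec> value = tf.constant(1.0f);   // implicit Maybe(TA)
//   if (value.DataType == typeof(Tensor))
//   {
//       var tensor = value.GetValueA();   // ok
//       // value.GetValueB();             // would throw TypeError
//   }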
internal class SingleDeviceSaver
{
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict;
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
}
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
{
_tensor_slice_dict = tensor_slice_dict.ToDictionary(
x => x.Key, x => x.Value.ToDictionary(
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
as IDictionary<string, Maybe<Tensor, SaveSpec>>);
}
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}
List<string> tensor_names = new();
List<Tensor> tensors = new();
List<string> slice_specs = new();
foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice in tensor_slices)
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
var tensor_value = spec.tensor;
if (tensor_value is not null)
{
tensor_names.Add(spec.name);
tensors.Add(tensor_value);
slice_specs.Add(spec.slice_spec);
}
}
else
{
var tensor = maybe_tensor.GetValueA();
tensor_names.Add(checkpoint_key);
tensors.Add(tensor);
slice_specs.Add(slice_spec);
}
}
}
// TODO: specify the device.
return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray());
}

public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options);

public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}
List<string> tensor_names = new();
List<TF_DataType> tensor_dtypes = new();
List<string> slice_specs = new();

foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice in tensor_slices)
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
tensor_dtypes.Add(spec.dtype);
slice_specs.Add(spec.slice_spec);
tensor_names.Add(spec.name);
}
else
{
var tensor = maybe_tensor.GetValueA();
tensor_dtypes.Add(tensor.dtype);
slice_specs.Add(slice_spec);
tensor_names.Add(checkpoint_key);
}
}
}

string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

// tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky.
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());

Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0;
foreach(var pair in _tensor_slice_dict)
{
var checkpoint_key = pair.Key;
var tensor_slices = pair.Value;
foreach(var slice_spec in tensor_slices.Keys)
{
var restored_tensor = restored_tensors[idx++];
if (!restored_tensor_dict.ContainsKey(checkpoint_key))
{
restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>();
}
restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor;
}
}
return restored_tensor_dict;
}

public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix));
}
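
// Illustrative sketch (not part of this commit): SingleDeviceSaver flattens a
// checkpoint_key -> (slice_spec -> Tensor) dictionary into the parallel name/slice/tensor
// arrays that save_v2/restore_v2 expect; e.g. (key format follows the upstream convention,
// `kernelTensor` is a hypothetical tensor):
//
//   var slices = new Dictionary<string, IDictionary<string, Tensor>>
//   {
//       ["dense/kernel/.ATTRIBUTES/VARIABLE_VALUE"] =
//           new Dictionary<string, Tensor> { [""] = kernelTensor }   // "" = whole tensor, no slicing
//   };
//   var saver = new SingleDeviceSaver(slices);
//   var saveOp = saver.save("ckpt/part-00000");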
/// <summary>
/// Saves checkpoints directly from multiple devices.
/// Note that this is a low-level utility which stores Tensors in the keys
specified by `SaveableObject`s. Higher-level utilities for object-based
/// checkpointing are built on top of it.
/// </summary>
public class MultiDeviceSaver
{
private Dictionary<string, SingleDeviceSaver> _single_device_savers;
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers;
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn;
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys;
/// <summary>
///
/// </summary>
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers"></param>
/// <param name="call_with_mapped_capture"></param>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{
_keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>();
_restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>();
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
foreach(var pair in serialized_tensors)
{
var obj = pair.Key;
var tensor_dict = pair.Value;
IFunctionHolder restore_fn;
if(obj == Trackable.None)
{
restore_fn = new FunctionHolder<object?>(() => null);
}
else
{
restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x =>
{
return obj._restore_from_tensors(x);
});
}

foreach(var item in tensor_dict)
{
var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.DataType != typeof(IDictionary<string, Tensor>))
{
spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = item.Value.GetValueA();
}
else
{
spec_to_tensor = item.Value.GetValueB();
}

foreach(var spec in spec_to_tensor)
{
var slice_spec = spec.Key;
var tensor = spec.Value;
if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec)))
{
throw new ValueError("Recieved multiple tensors with the same checkpoint key and " +
$"slice spec. This is invalid because one will overwrite the " +
$"other in the checkpoint. This indicates a bug in the Checkpoint key-generation.");
}
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

// skip the process of device name because of the lack of API.
var host_device = tensor.Device;
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
if (!internal_dict.ContainsKey(checkpoint_key))
{
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>();
}
internal_dict[checkpoint_key][slice_spec] = tensor;
}
}
}

_single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value));

_registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>();
if(registered_savers is not null && registered_savers.Count > 0)
{
// TODO: complete the implementation.
throw new NotImplementedException();
}
}

public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
{
if(options is null)
{
options = new CheckpointOptions();
}

tf.device("CPU"); // may be risky.
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
constant_op.constant(".part"), constant_op.constant("_temp/part"));
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));

Operation save_fn()
{
List<Tensor> saved_prefixes= new();
foreach(var saver in _registered_savers)
{
// TODO: implement it later.
throw new NotImplementedException();
}

int num_shards = _single_device_savers.Count;
List<Operation> sharded_saves = new();
var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards");
string? last_device = null;
int shard = 0;
foreach(var pair in _single_device_savers.OrderBy(x => x.Key))
{
var device = pair.Key;
var saver = pair.Value;
last_device = device;
// skip the extra process of device name because of lack of API.
tf.device(device);
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
saved_prefixes.Add(shard_prefix);
sharded_saves.Add(saver.save(shard_prefix, options));
}
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
{
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
tf.device(merge_device);
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
}
}

if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1)
{
// TODO: implement it. Currently `autograph` does not support functions with no parameters.
throw new NotImplementedException();
}
else
{
return save_fn();
}
}

public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options);

public IDictionary<string, Operation> restore(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
options = new CheckpointOptions();
}

IDictionary<string, Operation> restore_func()
{
Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
Dictionary<string, Operation> restore_ops = new();

foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key))
{
var device = single_saver.Key;
var saver = single_saver.Value;
tf.device(device);
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach(var pair in restored_tensor_dict)
{
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach(var item in slice_and_tensor)
{
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
{
if (!internal_dict.ContainsKey(checkpoint_key))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
}
else
{
internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor;
}
}
else
{
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
}
restore_fn_input_count[restore_fn]--;

if (restore_fn_input_count[restore_fn] == 0)
{
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach(var input in restore_fn_inputs[restore_fn])
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if(ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
}
}
}
}
}

foreach(var item in _registered_savers)
{
throw new NotImplementedException();
}
return restore_ops;
}

// TODO: complete the implementation. Currently skipped because of lack of API.
bool has_custom_device_saver = false;

if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver))
{
// TODO: implement it. Currently `autograph` does not support functions without parameters.
throw new NotImplementedException();
}
else
{
return restore_func();
}
}

/// <summary>
/// Serializes to a SaverDef referencing the current graph.
/// </summary>
public SaverDef to_proto()
{
var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename");
var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING);
var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING);
var save_tensor = traced_save_func(filename_tensor);
var restore_op = traced_restore_func(filename_tensor).op;
return new SaverDef()
{
FilenameTensorName = filename_tensor.name,
SaveTensorName = save_tensor.name,
RestoreOpName = restore_op.name,
Version = SaverDef.Types.CheckpointFormatVersion.V2
};
}

private Tensor _traced_save(Tensor file_prefix)
{
var save_op = save(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(new object[]{ save_op }))
{
return array_ops.identity(file_prefix);
}
}

private Tensor _traced_restore(Tensor file_prefix)
{
var restore_op = restore(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(restore_op.Values.ToArray()))
{
return array_ops.identity(file_prefix);
}
}

public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
{
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
foreach (var saveable in saveables)
{
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable });
serialized_tensors[trackable] = trackable.serialize_to_tensors();
}
return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures);
}

private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name)
{
return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") });
}
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards)
{
return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards);
}
}
}
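
For orientation, a minimal usage sketch of the saver assembled above, assuming eager execution and a collection of `MySaveableObject` instances prepared elsewhere (`saveables` below is hypothetical); this snippet is not part of the diff.

// Sketch only: build a MultiDeviceSaver from saveables, write a sharded
// checkpoint, then restore it. `saveables` is a hypothetical IEnumerable<MySaveableObject>.
var saver = MultiDeviceSaver.from_saveables(saveables);
Operation save_op = saver.save("checkpoints/ckpt-1");                 // string overload wraps tf.constant
IDictionary<string, Operation> restore_ops = saver.restore(tf.constant("checkpoints/ckpt-1"));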

+ 68
- 0
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Tensorflow.Train;

namespace Tensorflow
{
@@ -90,4 +91,71 @@ namespace Tensorflow
Dispose(false);
}
}

public abstract class DisposableTrackableObject: Trackable, IDisposable
{
protected IntPtr _handle;
protected bool _disposed;

protected DisposableTrackableObject()
{ }

protected DisposableTrackableObject(IntPtr handle)
=> _handle = handle;

private void Dispose(bool disposing)
{
if (_disposed)
return;

// First dispose managed resources; they might use the unmanaged resources.
if (disposing)
{
// dispose managed state (managed objects).
DisposeManagedResources();
}

// free unmanaged memory
if (_handle != IntPtr.Zero)
{
// Call the appropriate methods to clean up
// unmanaged resources here.
// If disposing is false,
// only the following code is executed.
DisposeUnmanagedResources(_handle);
_handle = IntPtr.Zero;
}

// Note disposing has been done.
_disposed = true;
}

/// <summary>
/// Dispose any managed resources.
/// </summary>
/// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks>
protected virtual void DisposeManagedResources()
{ }

/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected abstract void DisposeUnmanagedResources(IntPtr handle);

public void Dispose()
{
Dispose(true);
// This object will be cleaned up by the Dispose method.
// Therefore, you should call GC.SuppressFinalize to
// take this object off the finalization queue
// and prevent finalization code for this object
// from executing a second time.
GC.SuppressFinalize(this);
}

~DisposableTrackableObject()
{
Dispose(false);
}
}
}

+ 31
- 0
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Contexts;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
internal class execute
{
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
{
var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx));
var types = v.Select(t => t.dtype.as_datatype_enum());
return (types.ToArray(), v.ToArray());
}
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
{
string device_name = ctx.DeviceName;

ctx.ensure_initialized();
var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs);

return tensors;
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Exceptions/AssertionError.cs View File

@@ -0,0 +1,14 @@
namespace Tensorflow.Exceptions;

public class AssertionError : TensorflowException
{
public AssertionError() : base()
{

}

public AssertionError(string message) : base(message)
{

}
}

+ 85
- 1
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -304,7 +304,7 @@ namespace Tensorflow
}
}

private static OpList stripped_op_list_for_graph(GraphDef graph_def)
public static OpList stripped_op_list_for_graph(GraphDef graph_def)
{
var used_ops = ops_used_by_graph_def(graph_def);

@@ -345,5 +345,89 @@ namespace Tensorflow

return used_ops.ToArray();
}

private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value)
{
foreach (var attr_def in op_def.Attr)
{
if (attr_def.Name == attr_name)
{
if (attr_def.DefaultValue is null) return false;
// TODO: add new c_api `EqualAttrValueWrapper` and complete the check.
return true;
}
}

return false;
}

public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def)
{
Dictionary<string, FunctionDef> op_name_to_function = new();
foreach (var function_def in meta_graph_def.GraphDef.Library.Function)
{
op_name_to_function[function_def.Signature.Name] = function_def;
}

Action<NodeDef> _strip_node_default_valued_attrs = (node_def) =>
{
if (op_name_to_function.ContainsKey(node_def.Op)) return;

var op_def = op_def_registry.GetOpDef(node_def.Op);
if(op_def is null) return;

HashSet<string> attrs_to_strip = new();
foreach (var attr in node_def.Attr)
{
if (is_default_attr_value(op_def, attr.Key, attr.Value))
{
attrs_to_strip.Add(attr.Key);
}
}

foreach (var attr in attrs_to_strip)
{
node_def.Attr.Remove(attr);
}
};

foreach (var node_def in meta_graph_def.GraphDef.Node)
{
_strip_node_default_valued_attrs(node_def);
}

foreach (var function_def in meta_graph_def.GraphDef.Library.Function)
{
foreach (var function_node_def in function_def.NodeDef)
{
_strip_node_default_valued_attrs(function_node_def);
}
}

meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
}

/// <summary>
/// Extract the Op name from a Tensor name.
/// </summary>
/// <param name="tensor_name"></param>
/// <returns></returns>
public static string op_name(string tensor_name)
{
if (string.IsNullOrEmpty(tensor_name))
{
throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}.");
}

if (tensor_name.StartsWith("^"))
{
tensor_name = tensor_name.Substring(1);
}
if (tensor_name.Contains(":"))
{
return tensor_name.Split(':')[0];
}
return tensor_name;
}
}
}
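
Two hedged examples of the new `op_name` helper (inputs illustrative, not part of the diff):

// Sketch: op_name drops a leading control-dependency marker ('^') and the output index.
meta_graph.op_name("^dense/kernel:0");   // "dense/kernel"
meta_graph.op_name("Identity:1");        // "Identity"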

+ 2
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework.Models;
using Tensorflow.Graphs;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -10,7 +11,7 @@ namespace Tensorflow.Functions
/// <summary>
///
/// </summary>
public class ConcreteFunction
public class ConcreteFunction: Trackable
{
FuncGraph func_graph;
ForwardBackwardCall forward_backward;


+ 9
- 2
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -1,16 +1,23 @@
using System;
using Tensorflow.Train;

namespace Tensorflow
{
public class Function
public class Function: Trackable
{
#pragma warning disable CS0169 // The field 'Function._handle' is never used
private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used

public string Name { get; set; }
public Function()
{

}
public Function(string name)
{
Name = name;
}
}
}

+ 32
- 14
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;

@@ -6,14 +7,14 @@ namespace Tensorflow.Graphs
{
public class AutoGraph
{
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32)
{
string func_name = $"{func.Method.Name}_{ops.uid_function()}";

var graph = new FuncGraph(func_name);
graph.as_default();

var input = tf.placeholder(tf.int32);
var input = tf.placeholder(dtype);
var output = func(input);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
@@ -26,25 +27,33 @@ namespace Tensorflow.Graphs

return (Tensor input) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
}
using (var s = tf.Session(input.graph))
{
var output = func(input);
return output;
}
};
}

public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes)
{
string func_name = $"{func.Method.Name}_{ops.uid_function()}";

var graph = new FuncGraph(func_name);
graph.as_default();

var input1 = tf.placeholder(tf.int32);
var input2 = tf.placeholder(tf.int32);
var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32);
var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32);
var output = func(input1, input2);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
@@ -56,13 +65,22 @@ namespace Tensorflow.Graphs
return (Tensor a, Tensor b) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { a, b },
null,
1);
return result[0];
return result[0];
}
using (var s = tf.Session(a.graph))
{
Debug.Assert(a.graph == b.graph);
var output = func(a, b);
return output;
}
};
}
}
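
A minimal sketch of the new overloads, assuming float placeholders (illustrative only, not part of the diff):

// Sketch: trace a two-argument function with explicit placeholder dtypes, then invoke it;
// in eager mode the call dispatches through TFE_Execute.
var add = tf.autograph.to_graph((Tensor a, Tensor b) => a + b, tf.float32, tf.float32);
var sum = add(tf.constant(2.0f), tf.constant(3.0f));   // 5.0f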


+ 13
- 4
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs View File

@@ -1,9 +1,18 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition {
public class SoftmaxArgs : LayerArgs {
public Axis axis { get; set; } = -1;
}
public class SoftmaxArgs : LayerArgs
{
[JsonProperty("axis")]
public Axis axis { get; set; } = -1;
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
}
}

+ 19
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs View File

@@ -0,0 +1,19 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class AutoSerializeLayerArgs: LayerArgs
{
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
}
}

+ 41
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs View File

@@ -1,13 +1,18 @@
using System;
using Newtonsoft.Json;
using System;
using System.Xml.Linq;
using Tensorflow.Operations.Initializers;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: `activity_regularizer`
public class DenseArgs : LayerArgs
{
/// <summary>
/// Positive integer, dimensionality of the output space.
/// </summary>
[JsonProperty("units")]
public int Units { get; set; }

/// <summary>
@@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition
/// </summary>
public Activation Activation { get; set; }

private string _activationName;
[JsonProperty("activation")]
public string ActivationName
{
get
{
if (string.IsNullOrEmpty(_activationName))
{
return Activation.Method.Name;
}
else
{
return _activationName;
}
}
set
{
_activationName = value;
}
}

/// <summary>
/// Whether the layer uses a bias vector.
/// </summary>
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;

/// <summary>
/// Initializer for the `kernel` weights matrix.
/// </summary>
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;

/// <summary>
/// Initializer for the bias vector.
/// </summary>
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;

/// <summary>
/// Regularizer function applied to the `kernel` weights matrix.
/// </summary>
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; }

/// <summary>
/// Regularizer function applied to the bias vector.
/// </summary>
[JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; }

/// <summary>
/// Constraint function applied to the `kernel` weights matrix.
/// </summary>
[JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; }

/// <summary>
/// Constraint function applied to the bias vector.
/// </summary>
[JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; }

[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
}
}
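
With `LayerArgs` marked opt-in (see the LayerArgs diff below), only `[JsonProperty]`-annotated members are serialized. A hedged sketch of the resulting config (values illustrative, not part of the diff):

// Sketch only: serializing DenseArgs with Newtonsoft.Json.
var args = new DenseArgs { Units = 64, ActivationName = "relu", Name = "dense_1" };
string json = JsonConvert.SerializeObject(args);
// json resembles: {"units":64,"activation":"relu","use_bias":true,...,"name":"dense_1","trainable":true}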

+ 15
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs View File

@@ -1,9 +1,22 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Tensorflow.Keras.Common;

namespace Tensorflow.Keras.ArgsDefinition
{
public class InputLayerArgs : LayerArgs
{
[JsonIgnore]
public Tensor InputTensor { get; set; }
public bool Sparse { get; set; }
[JsonProperty("sparse")]
public virtual bool Sparse { get; set; }
[JsonProperty("ragged")]
public bool Ragged { get; set; }
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
}
}

+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs View File

@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
public class DataAdapterArgs
public class DataAdapterArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }


+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs View File

@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
public class DataHandlerArgs
public class DataHandlerArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }


+ 17
- 14
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -1,51 +1,54 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
public class LayerArgs
[JsonObject(MemberSerialization.OptIn)]
public class LayerArgs: IKerasConfig
{
/// <summary>
/// Indicates whether the layer's weights are updated during training
/// and whether the layer's updates are run during training.
/// </summary>
public bool Trainable { get; set; } = true;

public string Name { get; set; }
public virtual bool Trainable { get; set; } = true;
public virtual string Name { get; set; }

/// <summary>
/// Only applicable to input layers.
/// </summary>
public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;
public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;

/// <summary>
/// Whether the `call` method can be used to build a TF graph without issues.
/// This attribute has no effect if the model is created using the Functional
/// API. Instead, `model.dynamic` is determined based on the internal layers.
/// </summary>
public bool Dynamic { get; set; } = false;
public virtual bool Dynamic { get; set; } = false;

/// <summary>
/// Only applicable to input layers.
/// </summary>
public Shape InputShape { get; set; }
public virtual Shape InputShape { get; set; }

/// <summary>
/// Only applicable to input layers.
/// </summary>
public Shape BatchInputShape { get; set; }
public virtual Shape BatchInputShape { get; set; }

public int BatchSize { get; set; } = -1;
public virtual int BatchSize { get; set; } = -1;

/// <summary>
/// Initial weight values.
/// </summary>
public float[] Weights { get; set; }
public virtual float[] Weights { get; set; }

/// <summary>
/// Regularizer function applied to the output of the layer(its "activation").
/// </summary>
public IRegularizer ActivityRegularizer { get; set; }
public virtual IRegularizer ActivityRegularizer { get; set; }

public bool Autocast { get; set; }
public virtual bool Autocast { get; set; }

public bool IsFromConfig { get; set; }
public virtual bool IsFromConfig { get; set; }
}
}

+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
public class NodeArgs
public class NodeArgs: IKerasConfig
{
public ILayer[] InboundLayers { get; set; }
public int[] NodeIndices { get; set; }


+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
public class OptimizerV2Args
public class OptimizerV2Args: IKerasConfig
{
public string Name { get; set; }
public float LearningRate { get; set; } = 0.001f;


+ 5
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs View File

@@ -1,7 +1,10 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{
public class FlattenArgs : LayerArgs
public class FlattenArgs : AutoSerializeLayerArgs
{
[JsonProperty("data_format")]
public string DataFormat { get; set; }
}
}

+ 50
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs View File

@@ -0,0 +1,50 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedActivationJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(Activation);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
var token = JToken.FromObject("");
token.WriteTo(writer);
}
else if (value is not Activation)
{
throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var token = JToken.FromObject((value as Activation)!.GetType().Name);
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
throw new NotImplementedException();
//var dims = serializer.Deserialize(reader, typeof(string));
//if (dims is null)
//{
// throw new ValueError("Cannot deserialize 'null' to `Activation`.");
//}
//return new Shape((long[])(dims!));
}
}
}

+ 48
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs View File

@@ -0,0 +1,48 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedAxisJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(Axis);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
var token = JToken.FromObject(new int[] { });
token.WriteTo(writer);
}
else if (value is not Axis)
{
throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var token = JToken.FromObject((value as Axis)!.axis);
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var axis = serializer.Deserialize(reader, typeof(long[]));
if (axis is null)
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
}
return new Axis((int[])(axis!));
}
}
}

+ 73
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs View File

@@ -0,0 +1,73 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Common
{
public class CustomizedNodeConfigJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(NodeConfig);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
// JToken.FromObject(null) would throw; write a JSON null directly.
writer.WriteNull();
}
else if (value is not NodeConfig)
{
throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var config = value as NodeConfig;
var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex });
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var values = serializer.Deserialize(reader, typeof(object[])) as object[];
if (values is null)
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
if(values.Length != 3)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
if (values[0] is not string)
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
}
if (values[1] is not int)
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
}
if (values[2] is not int)
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
}
return new NodeConfig()
{
Name = values[0] as string,
NodeIndex = (int)values[1],
TensorIndex = (int)values[2]
};
}
}
}

+ 67
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -0,0 +1,67 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedShapeJsonConverter: JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(Shape);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if(value is null)
{
// JToken.FromObject(null) would throw; write a JSON null directly.
writer.WriteNull();
}
else if(value is not Shape)
{
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var shape = (value as Shape)!;
long?[] dims = new long?[shape.ndim];
for(int i = 0; i < dims.Length; i++)
{
if (shape.dims[i] == -1)
{
dims[i] = null;
}
else
{
dims[i] = shape.dims[i];
}
}
var token = JToken.FromObject(dims);
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
if(dims is null)
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
{
convertedDims[i] = dims[i] ?? (-1);
}
return new Shape(convertedDims);
}
}
}
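
A hedged round-trip sketch; the converter is attached to `Shape` via `[JsonConverter]` in Numpy/Shape.cs below (values illustrative):

// Sketch: unknown dimensions (-1) are written as JSON null and restored as -1.
var shape = new Shape(-1, 28, 28);
string json = JsonConvert.SerializeObject(shape);            // "[null,28,28]"
var restored = JsonConvert.DeserializeObject<Shape>(json);   // dims: -1, 28, 28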

+ 30
- 1
src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs View File

@@ -16,23 +16,27 @@

using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// Specifies the ndim, dtype and shape of every input to a layer.
/// </summary>
public class InputSpec
public class InputSpec: IKerasConfigable
{
public int? ndim;
public int? max_ndim;
public int? min_ndim;
Dictionary<int, int> axes;
Shape shape;
TF_DataType dtype;
public int[] AllAxisDim;

public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null,
int? min_ndim = null,
int? max_ndim = null,
Dictionary<int, int> axes = null,
Shape shape = null)
{
@@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine
axes = new Dictionary<int, int>();
this.axes = axes;
this.min_ndim = min_ndim;
this.max_ndim = max_ndim;
this.shape = shape;
this.dtype = dtype;
if (ndim == null && shape != null)
this.ndim = shape.ndim;

@@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine
AllAxisDim = axes.Select(x => x.Value).ToArray();
}

public IKerasConfig get_config()
{
return new Config()
{
DType = dtype == TF_DataType.DtInvalid ? null : dtype,
Shape = shape,
Ndim = ndim,
MinNdim = min_ndim,
MaxNdim = max_ndim,
Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value)
};
}

public override string ToString()
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}";

public class Config: IKerasConfig
{
public TF_DataType? DType { get; set; }
public Shape Shape { get; set; }
public int? Ndim { get; set; }
public int? MinNdim { get;set; }
public int? MaxNdim { get;set; }
public IDictionary<string, int> Axes { get; set; }
}
}
}

+ 4
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,10 +1,12 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Training;

namespace Tensorflow.Keras
{
public interface ILayer
public interface ILayer: IWithTrackable, IKerasConfigable
{
string Name { get; }
bool Trainable { get; }
@@ -19,8 +21,8 @@ namespace Tensorflow.Keras
List<IVariableV1> NonTrainableWeights { get; }
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }
TF_DataType DType { get; }
int count_params();
LayerArgs get_config();
}
}

+ 15
- 0
src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving
{
public interface IKerasConfig
{
}

public interface IKerasConfigable
{
IKerasConfig get_config();
}
}

+ 7
- 2
src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs View File

@@ -1,4 +1,5 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
@@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Saving
{
public class LayerConfig
public class LayerConfig: IKerasConfig
{
[JsonProperty("name")]
public string Name { get; set; }
[JsonProperty("class_name")]
public string ClassName { get; set; }
[JsonProperty("config")]
public LayerArgs Config { get; set; }
[JsonProperty("inbound_nodes")]
public List<NodeConfig> InboundNodes { get; set; }
}
}

+ 7
- 2
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -1,15 +1,20 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Saving
{
public class ModelConfig
public class ModelConfig : IKerasConfig
{
[JsonProperty("name")]
public string Name { get; set; }
[JsonProperty("layers")]
public List<LayerConfig> Layers { get; set; }
[JsonProperty("input_layers")]
public List<NodeConfig> InputLayers { get; set; }
[JsonProperty("output_layers")]
public List<NodeConfig> OutputLayers { get; set; }

public override string ToString()


+ 5
- 2
src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs View File

@@ -1,10 +1,13 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Common;

namespace Tensorflow.Keras.Saving
{
public class NodeConfig
[JsonConverter(typeof(CustomizedNodeConfigJsonConverter))]
public class NodeConfig : IKerasConfig
{
public string Name { get; set; }
public int NodeIndex { get; set; }


+ 35
- 0
src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel
{
public interface ISerializedAttributes
{
IDictionary<string, Trackable> Functions { get; }

IDictionary<string, Trackable> CheckpointableObjects { get; }

/// <summary>
/// Returns functions to attach to the root object during serialization.
/// </summary>
IDictionary<string, Trackable> FunctionsToSerialize { get; }

/// <summary>
/// Returns objects to attach to the root object during serialization.
/// </summary>
IDictionary<string, Trackable> ObjectsToSerialize{get; }

/// <summary>
/// Saves function dictionary, and validates dictionary values.
/// </summary>
/// <param name="function_dict"></param>
IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict);

/// <summary>
/// Saves objects to a dictionary, and validates the values.
/// </summary>
/// <param name="object_dict"></param>
IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict);
}
}

+ 21
- 0
src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs View File

@@ -0,0 +1,21 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
{
public class TensorShapeConfig
{
[JsonProperty("class_name")]
public string ClassName { get; set; } = "TensorShape";
[JsonProperty("items")]
public long?[] Items { get; set; }

public static implicit operator Shape(TensorShapeConfig shape)
=> shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());

public static implicit operator TensorShapeConfig(Shape shape)
=> new TensorShapeConfig() { Items = shape.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() };
}
}
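
The implicit operators make the mapping explicit in a short sketch (illustrative, not part of the diff):

// Sketch: -1 (unknown) maps to a null item and back.
TensorShapeConfig config = new Shape(-1, 10);   // Items == { null, 10 }
Shape back = config;                            // dims  == { -1, 10 }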

+ 43
- 1
src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs View File

@@ -9,10 +9,52 @@ namespace Tensorflow.ModelSaving
/// </summary>
public class SaveOptions
{
bool save_debug_info;
public bool save_debug_info = false;
public IList<string>? namespace_white_list { get; set; } = null;
public IDictionary<string, object>? function_aliases { get; set; } = null;
public string? experimental_io_device { get; set; } = null;
// TODO: experimental
public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None;
public bool experimental_custom_gradients { get; set; } = true;
public SaveOptions(bool save_debug_info = false)
{
this.save_debug_info = save_debug_info;
}
}

public class VariablePolicy
{
public string Policy { get; }
private VariablePolicy(string policy)
{
Policy = policy;
}
public static VariablePolicy None = new(null);
public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices");
public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables");

public bool save_variable_devices()
{
return this != VariablePolicy.None;
}

/// <summary>
/// Tries to convert `obj` to a VariablePolicy instance.
/// </summary>
/// <param name="obj"></param>
/// <returns></returns>
public static VariablePolicy from_obj(object obj)
{
if (obj is null) return VariablePolicy.None;
if (obj is VariablePolicy) return (VariablePolicy)obj;
var key = obj.ToString().ToLower();
return key switch
{
null => VariablePolicy.None,
"save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES,
"expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
};
}
}
}
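
A brief sketch of `from_obj` behaviour (inputs illustrative, not part of the diff):

// Sketch: from_obj accepts null, an existing policy, or a case-insensitive string.
var p0 = VariablePolicy.from_obj(null);                      // VariablePolicy.None
var p1 = VariablePolicy.from_obj("SAVE_VARIABLE_DEVICES");   // SAVE_VARIABLE_DEVICES
bool saveDevices = p1.save_variable_devices();               // true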

+ 10
- 1
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -14,20 +14,29 @@
limitations under the License.
******************************************************************************/

using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Common;

namespace Tensorflow
{
public record Axis(params int[] axis)
[JsonConverter(typeof(CustomizedAxisJsonConverter))]
public class Axis
{
public int[] axis { get; set; }
public int size => axis == null ? -1 : axis.Length;
public bool IsScalar { get; init; }

public int this[int index] => axis[index];

public Axis(params int[] axis)
{
this.axis = axis;
}

public static implicit operator int[]?(Axis axis)
=> axis?.axis;



+ 3
- 0
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -14,14 +14,17 @@
limitations under the License.
******************************************************************************/

using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Common;
using Tensorflow.NumPy;

namespace Tensorflow
{
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape
{
public int ndim => _dims == null ? -1 : _dims.Length;


+ 10
- 0
src/TensorFlowNET.Core/Operations/Initializers/Constant.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class Constant<T> : IInitializer
@@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers
T value;
bool _verify_shape;

private readonly Dictionary<string, object> _config;

public string ClassName => "Constant";
public IDictionary<string, object> Config => _config;

public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false)
{
this.value = value;
this.dtype = dtype;
_verify_shape = verify_shape;

_config = new Dictionary<string, object>();
_config["value"] = this.value;
}

public Tensor Apply(InitializerArgs args)


+ 9
- 1
src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs View File

@@ -14,10 +14,17 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class GlorotUniform : VarianceScaling
{
private readonly Dictionary<string, object> _config;

public override string ClassName => "GlorotUniform";
public override IDictionary<string, object> Config => _config;

public GlorotUniform(float scale = 1.0f,
string mode = "FAN_AVG",
bool uniform = true,
@@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers
seed: seed,
dtype: dtype)
{

_config = new Dictionary<string, object>();
_config["seed"] = _seed;
}
}
}

+ 7
- 0
src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -14,10 +14,17 @@
limitations under the License.
******************************************************************************/

using Newtonsoft.Json;
using System.Collections.Generic;

namespace Tensorflow
{
public interface IInitializer
{
[JsonProperty("class_name")]
string ClassName { get; }
[JsonProperty("config")]
IDictionary<string, object> Config { get; }
Tensor Apply(InitializerArgs args);
}
}

+ 7
- 0
src/TensorFlowNET.Core/Operations/Initializers/Ones.cs View File

@@ -14,12 +14,19 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class Ones : IInitializer
{
private TF_DataType dtype;

private readonly Dictionary<string, object> _config;

public string ClassName => "Ones";
public IDictionary<string, object> Config => new Dictionary<string, object>();

public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.dtype = dtype;


+ 6
- 1
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -1,4 +1,4 @@
/*****************************************************************************
/*****************************************************************************
Copyright 2023 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,6 +19,7 @@ using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Operations.Initializers;
using System.Collections.Generic;

public class Orthogonal : IInitializer
{
@@ -31,6 +32,10 @@ public class Orthogonal : IInitializer
_seed = seed;
}

private readonly Dictionary<string, object> _config;

public string ClassName => "Orthogonal";
public IDictionary<string, object> Config => throw new NotImplementedException();
public Tensor Apply(InitializerArgs args)
{
return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType);


+ 12
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class RandomNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed;
private TF_DataType dtype;

private readonly Dictionary<string, object> _config;

public string ClassName => "RandomNormal";
public IDictionary<string, object> Config => _config;

public RandomNormal(float mean = 0.0f,
float stddev = 0.05f,
int? seed = null,
@@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev;
this.seed = seed;
this.dtype = dtype;

_config = new Dictionary<string, object>();
_config["mean"] = this.mean;
_config["stddev"] = this.stddev;
_config["seed"] = this.seed;
}

public Tensor Apply(InitializerArgs args)


+ 12
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class RandomUniform : IInitializer
@@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers
private float maxval;
private TF_DataType dtype;

private readonly Dictionary<string, object> _config;

public string ClassName => "RandomUniform";
public IDictionary<string, object> Config => _config;

public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null)
{
this.dtype = dtype;
this.minval = minval;
this.maxval = maxval;
this.seed = seed;

_config = new Dictionary<string, object>();
_config["minval"] = this.minval;
_config["maxval"] = this.maxval;
_config["seed"] = this.seed;
}

public Tensor Apply(InitializerArgs args)


+ 11
- 0
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class TruncatedNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed;
private TF_DataType dtype;

private readonly Dictionary<string, object> _config;

public string ClassName => "TruncatedNormal";
public IDictionary<string, object> Config => _config;

public TruncatedNormal(float mean = 0.0f,
float stddev = 1.0f,
int? seed = null,
@@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev;
this.seed = seed;
this.dtype = dtype;
_config = new Dictionary<string, object>();
_config["mean"] = this.mean;
_config["stddev"] = this.stddev;
_config["seed"] = this.seed;
}

public Tensor Apply(InitializerArgs args)


+ 13
- 0
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -15,7 +15,9 @@
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

namespace Tensorflow.Operations.Initializers
{
@@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers
protected int? _seed;
protected TF_DataType _dtype;
protected bool _uniform;
private readonly Dictionary<string, object> _config;

public virtual string ClassName => "VarianceScaling";

public virtual IDictionary<string, object> Config => _config;

public VarianceScaling(float factor = 2.0f,
string mode = "FAN_IN",
@@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers
_seed = seed;
_dtype = dtype;
_uniform = uniform;

_config = new();
_config["scale"] = _scale;
_config["mode"] = _mode;
_config["distribution"] = _distribution;
_config["seed"] = _seed;
}

public Tensor Apply(InitializerArgs args)


+ 5
- 0
src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers
{
public class Zeros : IInitializer
@@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers
Shape shape;
TF_DataType dtype;

public string ClassName => "Zeros";
public IDictionary<string, object> Config => new Dictionary<string, object>();

public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.shape = shape;


+ 7
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -20,7 +20,9 @@ using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
using static Tensorflow.Binding;

@@ -75,6 +77,8 @@ namespace Tensorflow

public Shape BatchInputShape => throw new NotImplementedException();

public TensorShapeConfig BuildInputShape => throw new NotImplementedException();

public TF_DataType DType => throw new NotImplementedException();
protected bool built = false;
public bool Built => built;
@@ -143,7 +147,7 @@ namespace Tensorflow
throw new NotImplementedException();
}

public LayerArgs get_config()
public IKerasConfig get_config()
{
throw new NotImplementedException();
}
@@ -152,5 +156,7 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

public Trackable GetTrackable() { throw new NotImplementedException(); }
}
}

+ 55
- 4
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -1,6 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
@@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations
/// path in the input checkpoint_prefixes. This is useful when those paths are non
/// user-facing temporary locations.
/// </remarks>
public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints")
{
public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files));
result = null;
return null;
//try
//{
// var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
// new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }));
// result = null;
// return null;
//}
//catch (System.Exception)
//{
// return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs,
// allow_missing_files: allow_missing_files, name: name, ctx: ctx);
//}
}
var dict = new Dictionary<string, object>();
dict["checkpoint_prefixes"] = checkpoint_prefixes;
dict["destination_prefix"] = destination_prefix;
if (delete_old_dirs.HasValue)
dict["delete_old_dirs"] = delete_old_dirs.Value;
dict["delete_old_dirs"] = delete_old_dirs;
var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict);
return op;
}

//public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx)
//{
// checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING);
// destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING);
// var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix };
// var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files };
// var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name);
// result = null;
// return null;
//}

/// <summary>
/// Transforms a spectrogram into a form that's useful for speech recognition.
/// </summary>
@@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern));
return result[0];
}
var dict = new Dictionary<string, object>();
dict["input"] = input;
dict["pattern"] = pattern;
@@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards));
return result[0];
}
var dict = new Dictionary<string, object>();
dict["basename"] = basename;
dict["shard"] = shard;
@@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator));
return result[0];
}
var dict = new Dictionary<string, object>();
dict["inputs"] = inputs;
if (separator != null)


+ 32
- 0
src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -14,7 +14,9 @@
limitations under the License.
******************************************************************************/

using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -23,11 +25,41 @@ namespace Tensorflow
{
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var result = tf.Runner.TFE_FastPathExecute(
new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors }));
result = null;
return null;
}
catch (System.Exception)
{
return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx);
}
}
var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });

return _op;
}

public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx)
{
DataType[] attr_dtypes;
(attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx);
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);
var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray();
var attrs = new object[] { "dtypes", attr_dtypes };

var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name);
result = null;
return null;
}

public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });


+ 60
- 0
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -17,6 +17,9 @@
using System;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
@@ -38,6 +41,11 @@ namespace Tensorflow
{
return var is ResourceVariable;
}
public static bool is_resource_variable(Trackable var)
{
return var is BaseResourceVariable;
}

/// <summary>
/// Creates a variable handle with information to do shape inference.
@@ -171,5 +179,57 @@ namespace Tensorflow
return HandleData.Parser.ParseFrom(handle.BufferToArray());
}
}

/// <summary>
/// Copies an existing variable to a new graph, with no initializer.
/// </summary>
/// <param name="variable"></param>
public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable)
{
var new_variable = new UninitializedVariable(
trainable: variable.Trainable,
shape: variable.shape,
dtype: variable.dtype,
name: variable.SharedName,
aggregation: variable.Aggregation,
extra_handle_data: null);
new_variable._maybe_initialize_trackable();
return new_variable;
}

/// <summary>
/// Writes additional information of the variable into the SavedObject proto.
/// </summary>
/// <param name="resource_variable"></param>
/// <param name="proto"></param>
/// <param name="options"></param>
/// <param name="enforcing_naming"></param>
public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true)
{
// lack of API: `proto.Variable.SetInParent()`.
if(enforcing_naming && !resource_variable.Name.EndsWith(":0"))
{
throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " +
$"unexpected suffix in the name (expected ':0') which won't be restored.");
}
if(proto.Variable is null)
{
proto.Variable = new SavedVariable();
}
proto.Variable.Name = meta_graph.op_name(resource_variable.Name);
proto.Variable.Trainable = resource_variable.Trainable;
proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum();
// TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`.
proto.Variable.Aggregation = resource_variable.Aggregation;
proto.Variable.Shape = resource_variable.shape.as_proto();

if (options.experimental_variable_policy.save_variable_devices())
{
if (!string.IsNullOrEmpty(resource_variable.Device))
{
proto.Variable.Device = resource_variable.Device;
}
}
}
}
}

+ 9
- 1
src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs View File

@@ -156,7 +156,7 @@ namespace Tensorflow {
/// Nodes[0] is considered the root node.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes {
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes {
get { return nodes_; }
}

@@ -286,6 +286,7 @@ namespace Tensorflow {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject(SavedObject other) : this() {
children_ = other.children_.Clone();
dependencies_ = other.dependencies_.Clone();
slotVariables_ = other.slotVariables_.Clone();
saveableObjects_ = other.saveableObjects_.Clone();
switch (other.KindCase) {
@@ -328,6 +329,7 @@ namespace Tensorflow {
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
/// <summary>
/// Objects which this object depends on: named edges in the dependency
/// graph.
@@ -338,6 +340,11 @@ namespace Tensorflow {
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children {
get { return children_; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies {
get { return dependencies_; }
}

/// <summary>Field number for the "slot_variables" field.</summary>
public const int SlotVariablesFieldNumber = 3;
@@ -617,6 +624,7 @@ namespace Tensorflow {
return;
}
children_.Add(other.children_);
dependencies_.Add(other.dependencies_);
slotVariables_.Add(other.slotVariables_);
saveableObjects_.Add(other.saveableObjects_);
switch (other.KindCase) {


+ 16
- 0
src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs View File

@@ -198,6 +198,22 @@ namespace Tensorflow {
public TrackableObject() {
OnConstruction();
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) {
OnConstruction();
slotVariables_ = slot;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot,
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children
)
{
OnConstruction();
slotVariables_ = slot;
children_ = children;
}

partial void OnConstruction();



+ 1
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="Protobuf.Text" Version="0.6.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>


+ 18
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -202,6 +202,24 @@ namespace Tensorflow
_ => type.ToString()
};

public static string as_python_name(this TF_DataType type)
=> type switch
{
TF_DataType.TF_STRING => "str",
TF_DataType.TF_UINT8 => "uint8",
TF_DataType.TF_INT8 => "int8",
TF_DataType.TF_UINT32 => "uint32",
TF_DataType.TF_INT32 => "int32",
TF_DataType.TF_UINT64 => "uint64",
TF_DataType.TF_INT64 => "int64",
TF_DataType.TF_FLOAT => "float32",
TF_DataType.TF_DOUBLE => "float64",
TF_DataType.TF_BOOL => "bool",
TF_DataType.TF_RESOURCE => "resource",
TF_DataType.TF_VARIANT => "variant",
_ => type.ToString()
};

public static int get_datatype_size(this TF_DataType type)
=> type.as_base_dtype() switch
{


+ 67
- 2
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -1,6 +1,71 @@
namespace Tensorflow.Train
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{
public abstract class AutoTrackable : Trackable
public class AutoTrackable : Trackable
{
public void _delete_tracking(string name)
{
_maybe_initialize_trackable();
if (_unconditional_dependency_names.ContainsKey(name))
{
_unconditional_dependency_names.Remove(name);
for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--)
{
if (_unconditional_checkpoint_dependencies[i].Name == name)
{
_unconditional_checkpoint_dependencies.RemoveAt(i);
}
}
}
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if(save_type != SaveType.SAVEDMODEL)
{
return base._trackable_children(save_type, cache);
}

Dictionary<string, Trackable> functions = new();
// TODO: handle logging.
var properties = this.GetType().GetProperties();
foreach ( var property in properties )
{
string name = property.Name;
object value = property.GetValue(this, null);
if(value is Function || value is ConcreteFunction)
{
functions[name] = (Trackable)value;
}
}

// TODO: process the type `core_types.GenericFunction`.

Dictionary<string, Trackable> children = new();
foreach(var pair in CheckpointDependencies)
{
var name = pair.Name;
var child = pair.Refer;
if(child is ConcreteFunction) // or Generic function
{
continue;
}
if(functions.ContainsKey(name) && functions[name] != child)
{
throw new ValueError($"Can't save object because it has multiple children with the same " +
$"name. Object: {this}, attribute name: {name}, child 1: " +
$"{child}, child 2: {functions[name]}");
}
children[name] = child;
}

return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value);
}
}
}
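
As a rough illustration of the reflection pass above: when saving a SavedModel, any public property of type `Function` or `ConcreteFunction` on an `AutoTrackable` subclass is picked up as a child under the property's name and merged with the existing checkpoint dependencies. A minimal sketch under that assumption; the `MyModule` type and its `Serve` property are hypothetical:

using Tensorflow.Functions;
using Tensorflow.Train;

// Hypothetical trackable: `Serve` is discovered by _trackable_children(SaveType.SAVEDMODEL)
// via reflection and exported under the key "Serve".
class MyModule : AutoTrackable
{
    public ConcreteFunction Serve { get; set; }
}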

+ 12
- 0
src/TensorFlowNET.Core/Training/IWithTrackable.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Training
{
public interface IWithTrackable
{
Trackable GetTrackable();
}
}

+ 9
- 0
src/TensorFlowNET.Core/Training/LayerUtils.cs View File

@@ -0,0 +1,9 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Training
{

}

+ 6
- 1
src/TensorFlowNET.Core/Training/Optimizer.cs View File

@@ -351,7 +351,7 @@ namespace Tensorflow
/// <param name="var"></param>
/// <param name="name"></param>
/// <returns></returns>
protected IVariableV1 get_slot(IVariableV1 var, string name)
internal IVariableV1 get_slot(IVariableV1 var, string name)
{
var named_slots = _slots.ContainsKey(name) ? _slots[name] : null;
if (named_slots == null)
@@ -360,6 +360,11 @@ namespace Tensorflow
return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null;
}

internal IEnumerable<string> get_slot_names()
{
return _slots.Keys;
}

private string _var_key(IVariableV1 var)
{
return $"{var.Op.graph.graph_key}.{var.Op.name}";


+ 28
- 0
src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using static Tensorflow.Binding;

namespace Tensorflow
{
public class ResourceVariableSaveable : MySaveableObject
@@ -35,6 +37,32 @@ namespace Tensorflow
this.name = name;
}

public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name)
{
_var_device = var.Device;
_var_shape = var.shape;

Tensor _read_variable_closure(BaseResourceVariable v)
{
tf.device(v.Device);
if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
{
return null;
}
var x = v.read_value_no_copy();
tf.device("/device:CPU:0");
return array_ops.identity(x);
}

this.handle_op = var.Handle;
var tensor = _read_variable_closure(var);

var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype);
_op = var;
specs = new SaveSpec[] { spec };
this.name = name;
}

public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];


+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
public string slice_spec => _slice_spec;

private string _name;
public string name => _name;
public string name { get => _name; set => _name = value; }

private TF_DataType _dtype;
public TF_DataType dtype => _dtype;


+ 36
- 2
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -14,11 +14,31 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Checkpoint;

namespace Tensorflow
{
public class MySaveableObject
{
public Tensor op;
protected Maybe<Tensor, BaseResourceVariable> _op;
public Tensor op
{
get
{
if(_op.DataType == typeof(Tensor))
{
return _op.GetValueA();
}
else
{
throw new TypeError("The _op is not a tensor.");
}
}
set
{
_op = value;
}
}
public SaveSpec[] specs;
public string name;
public string device;
@@ -35,7 +55,7 @@ namespace Tensorflow

public MySaveableObject(Tensor op, SaveSpec[] specs, string name)
{
this.op = op;
this._op = op;
this.specs = specs;
this.name = name;
}
@@ -48,4 +68,18 @@ namespace Tensorflow
validate_shape: restored_shapes == null && op.shape.IsFullyDefined);
}
}

public class NoRestoreSaveable: MySaveableObject
{
public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor,
new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name)
{
}

public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{
return control_flow_ops.no_op();
}
}
}
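
The `op` property above is now backed by a two-case `Maybe<Tensor, BaseResourceVariable>`: the getter succeeds only when the stored value is a `Tensor` and throws `TypeError` otherwise. A small sketch of the tensor case, assuming an eager context so that `tf.constant` yields a concrete tensor:

using Tensorflow;
using static Tensorflow.Binding;

var t = tf.constant(1.0f);
var saveable = new MySaveableObject(t, new SaveSpec[] { new SaveSpec(t, "", "x") }, "x");
// _op currently holds a Tensor, so the getter returns it instead of throwing.
var back = saveable.op;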

+ 11
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs View File

@@ -0,0 +1,11 @@
using System.Collections.Generic;

namespace Tensorflow;

public record class AssetInfo
(
List<AssetFileDef> asset_defs,
Dictionary<object, object> asset_initializers_by_resource,
Dictionary<AssetInfo, string> asset_filename_map,
Dictionary<object, object> asset_index
);

+ 133
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs View File

@@ -0,0 +1,133 @@
using System;
using Tensorflow.Checkpoint;
using Tensorflow.Train;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow;

public class AugmentedGraphView: ObjectGraphView
{
private Dictionary<Trackable, IDictionary<string, Trackable>> _children_cache;
private Dictionary<string, IDictionary<Trackable, ISerializedAttributes>> _serialization_cache;
private List<string> _untraces_functions;
private Dictionary<ConcreteFunction, ConcreteFunction> _wrapped_functions;
public AugmentedGraphView(Trackable root): base(root)
{
_children_cache= new Dictionary<Trackable, IDictionary<string, Trackable>>();
_serialization_cache = new Dictionary<string, IDictionary<Trackable, ISerializedAttributes>>();
_untraces_functions = new List<string>();
_wrapped_functions = new Dictionary<ConcreteFunction, ConcreteFunction>();
}

public void set_signature(SignatureMap signature_map, IDictionary<ConcreteFunction, ConcreteFunction> wrapped_functions)
{
list_children(Root);
var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME;
if (!_children_cache.ContainsKey(Root))
{
_children_cache[Root] = new Dictionary<string, Trackable>();
}
_children_cache[Root][name] = signature_map;
_wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value);
}
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
if(serialization_cache is not null)
{
throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`.");
}

if (!_children_cache.ContainsKey(obj))
{
Dictionary<string, Trackable> children = new Dictionary<string, Trackable>();
_children_cache[obj] = children;
foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache))
{
var name = pair.Name;
var child = pair.Refer;
if(child is ConcreteFunction)
{
child = maybe_uncache_variable_captures((ConcreteFunction)child);
}
children[name] = child;
}

if (obj is Function && children.Count == 0)
{
_untraces_functions.Add(((Function)obj).Name);
}
}

List<TrackableReference> res = new();
foreach(var pair in _children_cache[obj])
{
res.Add(new TrackableReference(pair.Key, pair.Value));
}

return res;
}

private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function)
{
if (_wrapped_functions.ContainsKey(concrete_function))
{
return _wrapped_functions[concrete_function];
}
// Skip the process here because the feature is not implemented yet.
// In the future, we may add an attribute that specifies whether the variable is supposed to be cached.
//foreach(var capture in concrete_function.CapturedInputs)
//{

//}
return concrete_function;
}

public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{
Trackable get_merged_trackable(Trackable x)
{
// TODO: complete it with new definitions `Asset` and `TrackableConstant`.
return x;
}
var trackable_objects = base.breadth_first_traversal();

foreach(var obj in _children_cache.Keys)
{
// skip the deletion of cache (maybe do it later).
foreach(var pair in _children_cache[obj])
{
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
}
}

return base.breadth_first_traversal();
}

public List<(string, Trackable)> list_dependencies(Trackable obj)
{
IDictionary<string, Trackable> children;
if (!_children_cache.ContainsKey(obj))
{
children= new Dictionary<string, Trackable>();
}
else
{
children= _children_cache[obj];
}
List<(string, Trackable)> res = new();
foreach(var pair in obj.deserialization_dependencies(children))
{
res.Add((pair.Key, pair.Value));
}
return res;
}

public Trackable get_child(Trackable obj, string name)
{
return _children_cache[obj][name];
}
}

+ 33
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs View File

@@ -0,0 +1,33 @@
namespace Tensorflow;

public static class Constants
{
public static readonly string ASSETS_DIRECTORY = "assets";
public static readonly string ASSETS_KEY = "saved_model_assets";

public static readonly string DEBUG_DIRECTORY = "debug";

public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb";

public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra";

public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb";

public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";

public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op";

public static readonly string MAIN_OP_KEY = "saved_model_main_op";

public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb";
public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt";

public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1;

public static readonly string TRAIN_OP_KEY = "saved_model_train_op";

public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op";

public static readonly string VARIABLES_DIRECTORY = "variables";
public static readonly string VARIABLES_FILENAME = "variables";
}

+ 17
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs View File

@@ -0,0 +1,17 @@
using Tensorflow.Train;

namespace Tensorflow;

public class RevivedTypes
{
/// <summary>
/// Create a SavedUserObject from a trackable object.
/// </summary>
/// <param name="obj"></param>
/// <returns></returns>
public static SavedUserObject? serialize(Trackable obj)
{
// TODO: complete the implementation.
return null;
}
}

+ 9
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs View File

@@ -0,0 +1,9 @@
using System;

namespace Tensorflow;

public enum SaveType
{
SAVEDMODEL,
CHECKPOINT
}

+ 299
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -0,0 +1,299 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow;

public class SaveableView
{
private AugmentedGraphView _augmented_graph_view;
private SaveOptions _options;
private List<Trackable> _trackable_objects;
private List<Trackable> _nodes;
private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths;
private Dictionary<Trackable, int> _node_ids;
private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
_slot_variables;
private Dictionary<Trackable, string> _object_names;
private List<object> _gradient_functions; // to be completed
private List<RegisteredGradient> _gradient_defs; // to be completed
private List<ConcreteFunction> _concrete_functions;
private Dictionary<Tensor, int> _captured_tensor_node_ids;
private Dictionary<Trackable, IDictionary<string, ConcreteFunction>> _saveable_objects_map;
private Dictionary<Trackable, string> _obj_to_registered_saver;

public AugmentedGraphView AugmentedGraphView
{
get => _augmented_graph_view;
}
public Trackable Root
{
get => _nodes[0];
}
public List<Trackable> Nodes
{
get => _nodes;
}
public Dictionary<Trackable, int> NodeIds
{
get => _node_ids;
}
public List<RegisteredGradient> GradientDefs
{
get => _gradient_defs;
}
public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths
{
get => _node_paths;
}
public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options)
{
_augmented_graph_view = augmented_graph_view;
_options = options;

(_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) =
CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view);
// TODO: deal with untraced functions.
initialize_save_and_restore_functions();
initialize_nodes_and_concrete_functions();

_captured_tensor_node_ids = new();
}

private void initialize_save_and_restore_functions()
{
// TODO: deal with the return value of `get_checkpoint_factories_and_keys`.
var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names);
// skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver.
_obj_to_registered_saver = new();
_saveable_objects_map = new();
}

private void initialize_nodes_and_concrete_functions()
{
_nodes = _trackable_objects.ConvertAll(x => x); // shallow copy of the list (same Trackable references)
_gradient_functions = new();
_gradient_defs = new();
_concrete_functions = new();

// TODO: deal with the condition that obj in `_saveable_objects_map`.
// foreach (var obj in _nodes)
// {
//
// }

foreach (var obj in _nodes)
{
if (obj is ConcreteFunction)
{
_concrete_functions.Add((ConcreteFunction)obj);
}
}
}

public List<ConcreteFunction> get_concrete_resource_initializers()
{
// TODO: complete the implementation.
return new List<ConcreteFunction>();
}
public (Dictionary<Trackable, Trackable>, Dictionary<Tensor, Tensor>, AssetInfo) map_resources()
{
Debug.Assert(!tf.Context.executing_eagerly());

Dictionary<Trackable, Trackable> object_map = new();
Dictionary<Tensor, Tensor> tensor_map = new();

AssetInfo assetInfo = new(new List<AssetFileDef>(), new Dictionary<object, object>(),
new Dictionary<AssetInfo, string>(), new Dictionary<object, object>());

foreach (var node_id in dependency_sorted_node_ids())
{
var obj = _nodes[node_id];
var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options);
// TODO: deal with Asset (if obj is Asset)
foreach (var tensor in tensors)
{
_captured_tensor_node_ids[tensor] = node_id;
}
}

return (object_map, tensor_map, assetInfo);
}

/// <summary>
/// Returns topologically sorted nodes, sorted by dependencies.
/// </summary>
public List<int> dependency_sorted_node_ids()
{
Dictionary<int, IEnumerable<int>> dependency_map = new();
foreach (var node in _nodes)
{
var node_id = _node_ids[node];
List<int> deps = new List<int>();
dependency_map.Add(node_id, deps);
// TODO: deal with captured tensor.

foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node))
{
if (!_node_ids.ContainsKey(dep))
{
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
throw new ValueError(
$"Found an untracked dependency. Object {node_path} depends on {dep}, " +
$"but this dependency isn't listed as a child. Please track this child by " +
$"overriding `_trackable_children` or use `._track_trackable`.");
}
deps.Add(_node_ids[dep]);
}
}

try
{
return TrackableUtils.order_by_dependency(dependency_map);
}
catch (TrackableUtils.CyclicDependencyError err)
{
List<string> pretty_printed_nodes = new();
List<string> pretty_printed_dependencies = new();

foreach (var pair in err.LeftOverDependencyMap)
{
var x = pair.Key;
var deps = pair.Value;
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]);
pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})");
pretty_printed_dependencies.Add(
$"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]");
}

throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " +
$"Saving cannot continue until this cycle is resolved." +
$"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" +
$"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}");
}
}

/// <summary>
/// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph
/// </summary>
/// <param name="asset_index"></param>
/// <returns></returns>
public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index)
{
SavedObjectGraph proto = new();
fill_object_graph_proto(proto);
// TODO: complete the process of concrete functions.

int cnt = Math.Min(_nodes.Count, proto.Nodes.Count);
for (int i = 0; i < cnt; i++)
{
var obj = _nodes[i];
var obj_proto = proto.Nodes[i];
write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x));
}

return proto;
}

private static void write_object_proto(Trackable obj, SavedObject proto,
IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn)
{
// skip the process of type Asset
if (resource_variable_ops.is_resource_variable(obj))
{
var options = SaveContext.get_save_options();
(obj as BaseResourceVariable).write_object_proto(proto, options);
}
else if (obj is Function)
{
// TODO: complete it.
throw new NotImplementedException();
}
else if (obj is ConcreteFunction)
{
// TODO: complete it.
throw new NotImplementedException();
}
// skip the process of type `_CapturedTensor` and `CapturableResource`.
else
{
var registered_type_proto = RevivedTypes.serialize(obj);
if (registered_type_proto is null)
{
registered_type_proto = new SavedUserObject()
{
Identifier = obj.ObjectIdentifier,
Version = new VersionDef()
{
Producer = 1,
MinConsumer = 1,
BadConsumers = { }
}
};
}

proto.UserObject = new SavedUserObject(registered_type_proto);
}
// TODO: try get the registered_name from `registration`.
}

public void fill_object_graph_proto(SavedObjectGraph proto)
{
for (int node_id = 0; node_id < _nodes.Count; node_id++)
{
var node = _nodes[node_id];
Debug.Assert(_node_ids[node] == node_id);
SavedObject object_proto = new();
if (_slot_variables.TryGetValue(node, out var value))
{
object_proto.SlotVariables.AddRange(value);
}
// skip the check of type `_CapturedTensor`
foreach (var child in _augmented_graph_view.list_children(node))
{
var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
child_proto.NodeId = _node_ids[child.Refer];
child_proto.LocalName = child.Name;
object_proto.Children.Add(child_proto);
}

foreach (var pair in _augmented_graph_view.list_dependencies(node))
{
var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
child_proto.NodeId = _node_ids[pair.Item2];
child_proto.LocalName = pair.Item1;
object_proto.Dependencies.Add(child_proto);
}

if (_saveable_objects_map.ContainsKey(node))
{
// TODO: complete it.
throw new NotImplementedException();
}
else if(_obj_to_registered_saver.ContainsKey(node))
{
// TODO: complete it.
// We now skip it for the lack of `SavedObject.registered_saver` API.
throw new NotImplementedException();
}

proto.Nodes.Add(object_proto);
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs View File

@@ -0,0 +1,10 @@
namespace Tensorflow;

public static class TagConstants
{
public static readonly string SERVING = "serve";
public static readonly string TRAINING = "train";
public static readonly string EVAL = "eval";
public static readonly string GPU = "gpu";
public static readonly string TPU = "tpu";
}

+ 22
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow;

public class BuilderUtils
{
public static void copy_assets_to_destination_dir(IDictionary<AssetInfo, string> asset_filename_map,
string destination_dir, HashSet<string>? saved_files = null)
{
if (saved_files is null) saved_files = new HashSet<string>();

var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir);

// TODO: complete the implementation of this function.
if (asset_filename_map is not null && asset_filename_map.Count > 0)
{
throw new NotImplementedException();
}
}
}

+ 269
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -0,0 +1,269 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Google.Protobuf;
using Tensorflow.Checkpoint;
using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow;

public static partial class SavedModelUtils
{
private static readonly IEnumerable<int> byte_swappable = new List<TF_DataType>()
{
dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16,
dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32,
dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16,
TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32
}.Select(x => (int)x);
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj,
string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
{
if (options is null)
{
options = new SaveOptions();
}

var saved_model = new Tensorflow.SavedModel();
var meta_graph_def = new MetaGraphDef();
saved_model.MetaGraphs.Add(meta_graph_def);

var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) =
_build_meta_graph(obj, signatures, options, meta_graph_def);
saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION;

if (!experimental_skip_checkpoint)
{
SavedModelUtils.get_or_create_variables_dir(export_dir);
CheckpointOptions ckpt_options = new(options.experimental_io_device);
object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options);
}
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir);

if (tf.Context.executing_eagerly())
{
// tensorflow python has a check of `context.async_wait()` here.
}
// TODO: deal with `pywrap_saved_model.Save(export_dir)`.

var saved_model_serialized = saved_model.ToString();

// This state depends on some Python/C APIs that are not bound yet. Here we temporarily treat it as `true`.
if (true)
{
var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir),
tf.compat.as_str(Constants.FINGERPRINT_FILENAME));
// TODO: add c api and complete the fingerprint def.
var fingerprint_proto = "";
File.WriteAllText(fingerprint_path, fingerprint_proto);
}

var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB));
File.WriteAllBytes(path, saved_model.ToByteArray());
//File.WriteAllText(path, saved_model.ToString());

if (options.save_debug_info)
{
throw new NotImplementedException();
}
ops.dismantle_graph(exported_graph);

return (saved_nodes, node_paths);
}

private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>,
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
{
using (SaveContext.save_context(options))
{
if (ops.inside_function())
{
throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " +
"Move the call to the outer eagerly-executed context.");
}

if (meta_graph_def is null)
{
meta_graph_def = new MetaGraphDef();
}

AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
if (signatures is null)
{
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view);
}

// TODO: process signatures and wrapped_functions

SaveableView saveable_view = new SaveableView(augmented_graph_view, options);
TrackableSaver object_saver = new TrackableSaver(augmented_graph_view);
var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures,
options.namespace_white_list, options.experimental_custom_gradients);
if (options.function_aliases is not null)
{
var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases;
foreach (var pair in options.function_aliases)
{
var alias = pair.Key;
var func = pair.Value;
// TODO: complete it.
throw new NotImplementedException();
}
}

var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index);
meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto);

return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths);
}
}

private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view,
ConcreteFunction signatures, IEnumerable<string> namespace_whitelist,
bool save_custom_gradients)
{
var resource_initializers = saveable_view.get_concrete_resource_initializers();
var exported_graph = new Graph();

Dictionary<Trackable, Trackable> object_map;
Dictionary<Tensor, Tensor> tensor_map;
AssetInfo asset_info;
var g = exported_graph.as_default();
(object_map, tensor_map, asset_info) = saveable_view.map_resources();
// TODO: deal with signatures.
if (save_custom_gradients)
{
// TODO: trace gradient functions.
}

foreach (var resource_initializer_function in resource_initializers)
{
// List<Trackable> asset_dependencies = new();
// TODO: deal with initializers
}

// using(ops.control_dependencies(...))
var init_op = control_flow_ops.no_op();
if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY))
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name);
}
else
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef();
}
// Lack `CopyFrom` API
// meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]

g.Exit();

foreach (var obj in object_map.Values)
{
obj._maybe_initialize_trackable();
}

// TODO: add the implementation of `call_with_mapped_functions`.
var (named_saveable_objects, registered_savers) =
SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false);
var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false);

var eg = exported_graph.as_default();
var saver_def = saver.to_proto();
meta_graph_def.SaverDef = saver_def;
eg.Exit();


saveable_view.dependency_sorted_node_ids();

var graph_def = exported_graph.as_graph_def(true);
graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs);
verify_ops(graph_def, namespace_whitelist);

meta_graph_def.GraphDef = new GraphDef(graph_def);
meta_graph_def.MetaInfoDef = new();
meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING);
meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION;
// TODO: add git version.
meta_graph_def.MetaInfoDef.TensorflowGitVersion = "";
meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
meta_graph_def.MetaInfoDef.StrippedOpList = new();
meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef));
meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs);
// TODO: deal with signatures here.
meta_graph.strip_graph_default_valued_attrs(meta_graph_def);

if (!BitConverter.IsLittleEndian)
{
swap_function_tensor_content(meta_graph_def);
}

return (asset_info, exported_graph);
}

private static void verify_ops(GraphDef graph_def, IEnumerable<string>? namespace_whitelist)
{
return;
// if (namespace_whitelist is null || !namespace_whitelist.Any())
// {
// return;
// }
// skip the check for the lack of `meta_graph.ops_used_by_graph_def`.
}

public static void swap_function_tensor_content(MetaGraphDef meta_graph_def)
{
var functions = meta_graph_def.GraphDef.Library.Function;
foreach (var function in functions)
{
var node_def = function.NodeDef;
foreach (var node in node_def)
{
if (node.Op == "Const")
{
var tensor = node.Attr["value"].Tensor;
byte_swap_tensor_content(tensor);
}
}
}
}

public static void byte_swap_tensor_content(TensorProto tensor)
{
if (byte_swappable.Contains((int)tensor.Dtype))
{
var tshape = tensor.TensorShape.Dim;
var tensor_bytes = tensor.TensorContent;
if (tensor_bytes is not null && !tensor_bytes.IsEmpty)
{
long tensor_size = 1;
foreach (var sz in tshape)
{
tensor_size *= sz.Size;
}

var chunksize = tensor_bytes.Length / tensor_size;
List<byte> reversed_bytes = new();
for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize)
{
var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse();
reversed_bytes.AddRange(current);
}
tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray());
}
}
}
}
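
The byte-swapping loop above reverses each element-sized chunk of the tensor's raw bytes so that constants written on a big-endian host stay readable (and vice versa). A small stand-alone sketch of the same chunk reversal on a plain byte array; the values are only illustrative:

using System;
using System.Collections.Generic;
using System.Linq;

// Two float32 values occupy 4 bytes each in little-endian order;
// reversing every 4-byte chunk flips them to big-endian.
byte[] content = { 0x00, 0x00, 0x80, 0x3F,    // 1.0f
                   0x00, 0x00, 0x00, 0x40 };  // 2.0f
int chunk = 4;
var swapped = new List<byte>();
for (int i = 0; i < content.Length; i += chunk)
    swapped.AddRange(content.Skip(i).Take(chunk).Reverse());
Console.WriteLine(BitConverter.ToString(swapped.ToArray()));
// 3F-80-00-00-40-00-00-00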

+ 53
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs View File

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.ModelSaving;

namespace Tensorflow.Training.Saving.SavedModel
{
/// <summary>
/// A context for building a graph of SavedModel.
/// </summary>
public static class SaveContext
{
// TODO: make it thread-safe.
private static bool _in_save_context = false;
private static SaveOptions _save_options = null;

public static bool in_save_context() => _in_save_context;
public static SaveOptions get_save_options()
{
if (!in_save_context())
{
throw new ValueError("Not in a SaveContext.");
}
return _save_options;
}
public static SaveContextHandler save_context(SaveOptions options)
{
return new SaveContextHandler(options);
}
public class SaveContextHandler: IDisposable
{
private bool _old_in_save_context;
private SaveOptions _old_save_options;
public SaveContextHandler(SaveOptions options)
{
if (SaveContext.in_save_context())
{
throw new ValueError("Already in a SaveContext.");
}
_old_in_save_context = SaveContext._in_save_context;
SaveContext._in_save_context = true;
_old_save_options = SaveContext._save_options;
SaveContext._save_options = options;
}
public void Dispose()
{
SaveContext._in_save_context = _old_in_save_context;
SaveContext._save_options = _old_save_options;
}
}
}
}
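
A short sketch of how the handler above is meant to be used: options set when the scope is opened are visible to any code that calls `SaveContext.get_save_options()` until the `using` block is disposed, which restores the previous state. This assumes the parameterless `SaveOptions` constructor used elsewhere in this diff:

using Tensorflow.ModelSaving;
using Tensorflow.Training.Saving.SavedModel;

var options = new SaveOptions();
using (SaveContext.save_context(options))
{
    // Inside the scope the save options are globally visible.
    System.Diagnostics.Debug.Assert(SaveContext.in_save_context());
    var current = SaveContext.get_save_options();  // returns `options`
}
// Outside the scope in_save_context() is false again.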

+ 107
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs View File

@@ -0,0 +1,107 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow;

public static class SignatureSerializationUtils
{
internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature";
internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures";
internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5;
public static SignatureMap create_signature_map(IDictionary<string, Trackable> signatures)
{
var signature_map = new SignatureMap();
foreach (var pair in signatures)
{
var name = pair.Key;
var func = pair.Value;
Debug.Assert(func is ConcreteFunction);
// TODO: assert the `func.structured_outputs` and arg_keywords.
signature_map._add_signature(name, (ConcreteFunction)func);
}

return signature_map;
}

public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view)
{
var children = graph_view.list_children(graph_view.Root);
List<Trackable> possible_signatures = new();
foreach (var item in children)
{
var name = item.Name;
var child = item.Refer;
if(child is not (Function or ConcreteFunction))
{
continue;
}
if(name == DEFAULT_SIGNATURE_ATTR)
{
Debug.Assert(child is ConcreteFunction);
return (ConcreteFunction)child;
}
ConcreteFunction concrete = get_signature(child);
if(concrete is not null && valid_signature(concrete))
{
possible_signatures.Add(concrete);
}
}

if(possible_signatures.Count == 1)
{
var signature = get_signature(possible_signatures[0]);
if(signature is not null && valid_signature(signature))
{
return signature;
}
}
return null;
}

private static ConcreteFunction get_signature(Trackable function)
{
// TODO: implement it.
return null;
}

private static bool valid_signature(ConcreteFunction concreate_function)
{
// TODO: implement it.
return false;
}
}

public class SignatureMap: Trackable
{
private Dictionary<string, Trackable> _signatures;

public SignatureMap()
{
_signatures = new();
}

public void _add_signature(string name, ConcreteFunction concrete_function)
{
_signatures[name] = concrete_function;
}
public void _add_signature(string name, Function concrete_function)
{
_signatures[name] = concrete_function;
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if (save_type != SaveType.SAVEDMODEL)
{
return new Dictionary<string, Trackable>();
}

return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
}
}

+ 57
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs View File

@@ -0,0 +1,57 @@
using System.IO;
using System.Security.Cryptography.X509Certificates;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow;

public static partial class SavedModelUtils
{
/// <summary>
/// Return variables sub-directory, or create one if it doesn't exist.
/// </summary>
/// <returns></returns>
public static string get_or_create_variables_dir(string export_dir)
{
var variables_dir = get_variables_dir(export_dir);
Directory.CreateDirectory(variables_dir);
return variables_dir;
}

/// <summary>
/// Return variables sub-directory in the SavedModel.
/// </summary>
/// <param name="export_dir"></param>
/// <returns></returns>
public static string get_variables_dir(string export_dir)
{
return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY));
}

public static string get_variables_path(string export_dir)
{
return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME));
}

/// <summary>
/// Return assets sub-directory, or create one if it doesn't exist.
/// </summary>
/// <param name="export_dir"></param>
/// <returns></returns>
public static string get_or_create_assets_dir(string export_dir)
{
var assets_destination_dir = get_assets_dir(export_dir);
Directory.CreateDirectory(assets_destination_dir);
return assets_destination_dir;
}

/// <summary>
/// Return path to asset directory in the SavedModel.
/// </summary>
/// <param name="export_dir"></param>
/// <returns></returns>
public static string get_assets_dir(string export_dir)
{
return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY));
}
}
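
Combined with the constants added earlier in this diff, the helpers above pin down the standard on-disk layout of an export directory: the protobuf goes to saved_model.pb at the top level, checkpoint shards under variables/, and auxiliary files under assets/. A small sketch of the resolved paths for a hypothetical export directory named "my_model" (the exact separator depends on the OS):

using System;
using Tensorflow;

// get_variables_dir  -> "my_model/variables"
// get_variables_path -> "my_model/variables/variables"   (checkpoint prefix)
// get_assets_dir     -> "my_model/assets"
Console.WriteLine(SavedModelUtils.get_variables_dir("my_model"));
Console.WriteLine(SavedModelUtils.get_variables_path("my_model"));
Console.WriteLine(SavedModelUtils.get_assets_dir("my_model"));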

+ 254
- 2
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -16,12 +16,38 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using Tensorflow.Training;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class saveable_object_util
/// <summary>
/// A SaveableObject that defines `Trackable` checkpointing steps.
/// </summary>
public class TrackableSaveable : MySaveableObject
{
private string _prefix;
private IEnumerable<string> _local_names;
private Trackable _trackable;
private bool _call_with_mapped_captures;
// TODO: revise the implementation. Currently the constructor parameters of this class conflict with those of its base class.
public TrackableSaveable(Trackable obj, IEnumerable<SaveSpec> specs, string name, IEnumerable<string> local_names,
string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name)
{
_prefix = prefix;
_trackable = obj;
_local_names = local_names;
_call_with_mapped_captures = call_with_mapped_captures;
}

// TODO: complete this class.
}
public static class saveable_object_util
{
/// <summary>
/// Returns the variables and names that will be used for a Saver.
@@ -52,7 +78,7 @@ namespace Tensorflow
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
@@ -74,6 +100,73 @@ namespace Tensorflow
}
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Trackable obj, string name)
{
// The `obj` may be a `Variable` or a `Trackable`.
if (obj is BaseResourceVariable)
{
var variable = obj as BaseResourceVariable;
if (variable.InGraphMode)
{
yield return new ResourceVariableSaveable(variable.GraphElement, "", name);
}
else
{
yield return new ResourceVariableSaveable(variable, "", name);
}
}
else
{
foreach(var pair in saveable_objects_from_trackable(obj))
{
var attr = pair.Key;
var factory = pair.Value;
string full_name;
if(attr == Trackable.Constants.VARIABLE_VALUE_KEY)
{
full_name = name;
}
else
{
full_name = name + "_" + attr;
}
if(factory.DataType == typeof(ResourceVariable))
{
var variable = factory.GetValueA();
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
{
yield return op;
}
}
else
{
var variable = factory.GetValueB();
foreach (var op in saveable_objects_for_op(variable, variable.name))
{
yield return op;
}
}
}
}
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(MySaveableObject obj, string name)
{
yield return obj;
}

public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true)
{
op_list = op_list.OrderBy(x => x.Name).ToArray();
@@ -121,5 +214,164 @@ namespace Tensorflow

return names_to_saveables;
}

public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
{
// skip the process of type `PythonState`

if (trackable_has_serialize_to_tensor(obj))
{
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
var tensor_dict = obj.serialize_to_tensors();

List<SaveSpec> specs = new();
List<string> local_names = new();
string prefix = SaveableCompat.get_saveable_name(obj) ?? "";
foreach(var pair in tensor_dict)
{
var tensor_name = pair.Key;
var maybe_tensor = pair.Value;
local_names.Add(tensor_name);
string spec_name = name + TrackableUtils.escape_local_name(tensor_name);

IDictionary<string, Tensor> internal_dict;
if(maybe_tensor.DataType == typeof(Tensor))
{
internal_dict= new Dictionary<string, Tensor>();
internal_dict[""] = maybe_tensor.GetValueA();
}
else
{
internal_dict = maybe_tensor.GetValueB();
}

foreach(var item in internal_dict)
{
specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
}
}
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new();
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
return res;
}
else
{
return obj.gather_saveables_for_checkpoint();
}
}

public static bool trackable_has_serialize_to_tensor(Trackable obj)
{
return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable);
}

internal static string convert_to_string(string x)
{
return tf.compat.as_str(x);
}

/// <summary>
/// Converts a list of SaveableObjects to a tensor dictionary.
/// </summary>
/// <param name="saveables"></param>
public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
{
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach (var saveable in saveables)
{
foreach (var spec in saveable.specs)
{
// skip the check that if `spec` is callable.
var name = convert_to_string(spec.name);
var slice_spec = convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec))
{
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor;
}
else
{
tensor_dict[name] = spec.tensor;
}
}
}
return tensor_dict;
}

/// <summary>
/// Generates `Trackable._restore_from_tensors` from SaveableObjects.
/// </summary>
/// <returns></returns>
public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables)
{
return (restored_tensors) =>
{
Dictionary<string, Operation> restored_ops = new();

foreach(var saveable in saveables)
{
List<Tensor> saveable_restored_tensors = new();
foreach(var spec in saveable.specs)
{
var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name));
var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec);

var maybe_tensor = restored_tensors[name];
IDictionary<string, Tensor> dict;
if(maybe_tensor.DataType == typeof(Tensor))
{
dict = new Dictionary<string, Tensor>();
dict[""] = maybe_tensor.GetValueA();
}
else
{
dict = maybe_tensor.GetValueB();
}
saveable_restored_tensors.Add(dict[slice_spec]);
}
restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null);
}
return restored_ops;
};
}
}

public class SaveableCompatibilityConverter: Trackable
{
private object _obj;
private IList<MySaveableObject> _saveables;
public SaveableCompatibilityConverter(object obj, IList<MySaveableObject> saveables)
{
_obj= obj;
_saveables= saveables;
}

public object Obj => _obj;
public IList<MySaveableObject> mySaveables=> _saveables;

public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{
return saveable_object_util.saveable_object_to_tensor_dict(_saveables);
}

/// <summary>
/// Returns the restore ops defined in the Saveables.
/// </summary>
/// <param name="restored_tensors"></param>
/// <returns></returns>
public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
{
List<string> expected_keys = new();
foreach(var saveable in _saveables)
{
expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name))));
}
if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys))
{
throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." +
$"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}");
}
return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors);
}
}
}
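
As a rough sketch of how the converter above interacts with the tensor-dict helpers: a saveable whose single spec has an empty slice spec serializes to a dictionary keyed directly by the spec name, which is the shape that `_restore_from_tensors` later expects back. Assuming an eager context so `tf.constant` produces a concrete tensor:

using System.Collections.Generic;
using Tensorflow;
using static Tensorflow.Binding;

var t = tf.constant(1.0f);
var name = "layer/.ATTRIBUTES/kernel";
var saveable = new MySaveableObject(t, new SaveSpec[] { new SaveSpec(t, "", name) }, name);

// Empty slice spec => the entry is keyed by the spec name and holds the tensor itself.
var tensor_dict = saveable_object_util.saveable_object_to_tensor_dict(
    new List<MySaveableObject> { saveable });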

+ 186
- 6
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -14,13 +14,63 @@
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving;
using Tensorflow.Training;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{
public abstract class Trackable
public abstract class Trackable: IWithTrackable
{
/// <summary>
/// Corresponding to tensorflow/python/trackable/constants.py
/// </summary>
public static class Constants
{
public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH";
public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE";
public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON";
}
protected int _self_update_uid;
protected IDictionary<string, Trackable> _unconditional_dependency_names;

protected IList<TrackableReference> _unconditional_checkpoint_dependencies;

protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
private bool _manual_tracking = true;

private static Trackable _none = new AutoTrackable();
/// <summary>
/// This is a workaround for the fact that C# does not allow a `Dictionary` key to be null.
/// The `None` can be any object that inherits `Trackable`.
/// This property is supposed to be used only internally.
/// </summary>
public static Trackable None
{
get
{
return _none;
}
}
public Trackable GetTrackable()
{
return this;
}
public virtual string ObjectIdentifier
{
get => "_generic_user_object";
}
public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; }
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; }
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; }
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; }

/// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`.
@@ -47,9 +97,13 @@ namespace Tensorflow.Train
// assign again. It will add this variable to our dependencies, and if there
// is a non-trivial restoration queued, it will handle that. This also
// handles slot variables.
if (!args.Overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: args.Name,
overwrite: args.Overwrite);
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{
var temp = new_variable as Trackable;
var res = _track_trackable(temp, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1);
return res as IVariableV1;
}
else
return new_variable;
}
@@ -73,10 +127,136 @@ namespace Tensorflow.Train
/// <summary>
/// Initialize dependency management.
/// </summary>
protected void _maybe_initialize_trackable()
public void _maybe_initialize_trackable()
{
// _self_unconditional_checkpoint_dependencies = []
if(_unconditional_checkpoint_dependencies is not null)
{
return;
}
_self_update_uid = -1;
_unconditional_checkpoint_dependencies = new List<TrackableReference>();
_unconditional_dependency_names = new Dictionary<string, Trackable>();
}

public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache)
{
_maybe_initialize_trackable();
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
}

public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false)
{
_maybe_initialize_trackable();
if (!_manual_tracking) return trackable;
var new_reference = new TrackableReference(name, trackable);
var current_object = _lookup_dependency(name);

if(current_object is null)
{
_unconditional_checkpoint_dependencies.Add(new_reference);
_handle_deferred_dependencies(name, trackable);
}
_unconditional_dependency_names[name] = trackable;
return trackable;
}

/// <summary>
/// Pop and load any deferred checkpoint restores into `trackable`.
/// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for
/// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are
/// considered fulfilled and so are deleted.
/// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations.
/// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves,
/// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context
/// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs).
/// </summary>
/// <param name="name"></param>
/// <param name="trackable"></param>
public virtual void _handle_deferred_dependencies(string name, Trackable trackable)
{
//_maybe_initialize_trackable();
//trackable._maybe_initialize_trackable();
// TODO: complete the implementation.
}

public virtual Trackable? _lookup_dependency(string name)
{
if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency;
else return null;
}

public static Trackable convert_to_trackable(object obj, object? parent = null)
{
if (obj is Trackable)
{
return (Trackable)obj;
}
else
{
throw new NotImplementedException();
}
}

public virtual IDictionary<string, Trackable> deserialization_dependencies(IDictionary<string, Trackable> children)
{
return new Dictionary<string, Trackable>();
}

public virtual (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(
SaveOptions? save_options)
{
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>());
}

public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable> object_map,
IDictionary<Tensor, Tensor> tensor_map, SaveOptions? options = null)
{
var (self_object_map, self_tensor_map) = map_resources(options);
foreach (var pair in self_object_map)
{
object_map.Add(pair);
}
foreach (var pair in self_tensor_map)
{
tensor_map.Add(pair);
}

return self_tensor_map.Keys.ToList();
}

public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
throw new NotImplementedException();
}
else
{
return _self_saveable_object_factories;
}
}

/// <summary>
/// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors`
/// if you are defining a custom resource or variable with custom ops.
/// Otherwise, please store the state of your trackable in `tf.Variable` objects
/// and add them to Trackable object hierarchy using `setattr` (for subclasses
/// of `AutoTrackable`) or overriding the `_trackable_children` method.
/// </summary>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{
throw new NotImplementedException();
}

public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
{
throw new NotImplementedException();
}
}

public record class TrackableReference(string Name, Trackable Refer);
}
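
A rough sketch of the dependency-tracking contract above: registering a child with `_track_trackable` under a local name makes it show up both in `CheckpointDependencies` and in the default `_trackable_children`, which is what the checkpoint and SavedModel writers traverse. The `Root` and `Leaf` classes are hypothetical:

using Tensorflow.Train;

class Leaf : AutoTrackable { }

class Root : AutoTrackable
{
    public Root(Leaf child)
    {
        // Registers `child` under the local name "leaf"; it now appears in
        // CheckpointDependencies and in _trackable_children(SaveType.CHECKPOINT).
        _track_trackable(child, "leaf");
    }
}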

+ 172
- 0
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -0,0 +1,172 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;

namespace Tensorflow.Training;

public static class TrackableUtils
{
public class CyclicDependencyError: System.Exception
{
public IDictionary<int, IEnumerable<int>> LeftOverDependencyMap { get; }
public CyclicDependencyError(IDictionary<int, IEnumerable<int>> leftover_dependency_map): base()
{
LeftOverDependencyMap = leftover_dependency_map;
}
public CyclicDependencyError(IDictionary<int, List<int>> leftover_dependency_map): base()
{
LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable());
}
}
private static string _ESCAPE_CHAR = ".";
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
{
return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name)));
}

public static string escape_local_name(string name)
{
return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S");
}
public static string checkpoint_key(string object_path, string local_name)
{
var key_suffix = escape_local_name(local_name);
if (local_name == SERIALIZE_TO_TENSORS_NAME)
{
key_suffix = "";
}

return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}";
}

/// <summary>
/// Topologically sorts the keys of a map so that dependencies appear first.
/// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
/// </summary>
/// <param name="dependency_map"></param>
/// <exception cref="ValueError"></exception>
public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map)
{
Dictionary<int, HashSet<int>> reverse_dependency_map = new();
foreach (var pair in dependency_map)
{
foreach (var dep in pair.Value)
{
if (reverse_dependency_map.ContainsKey(dep))
{
reverse_dependency_map[dep].Add(pair.Key);
}
else
{
reverse_dependency_map[dep] = new HashSet<int>();
reverse_dependency_map[dep].Add(pair.Key);
}
}
}
// Validate that all values in the dependency map are also keys.
var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys);
if (unknown_keys.Any())
{
throw new ValueError(
$"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}");
}
// Generate the list sorted by objects without dependencies -> dependencies.
// The returned list will reverse this.
List<int> reversed_dependency_arr = new();

Queue<int> to_visit = new();
foreach (var x in dependency_map.Keys)
{
if (!reverse_dependency_map.ContainsKey(x))
{
to_visit.Enqueue(x);
}
}

while (to_visit.Count > 0)
{
var x = to_visit.Dequeue();
reversed_dependency_arr.Add(x);
foreach (var dep in dependency_map[x].Distinct())
{
var edges = reverse_dependency_map[dep];
edges.Remove(x);
if (edges.Count == 0)
{
to_visit.Enqueue(dep);
if (!reverse_dependency_map.Remove(dep))
{
throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map");
}
}
}
}

if (reverse_dependency_map.Count > 0)
{
Dictionary<int, List<int>> leftover_dependency_map = new();
foreach (var pair in reverse_dependency_map)
{
foreach (var x in pair.Value)
{
if (leftover_dependency_map.ContainsKey(x))
{
leftover_dependency_map[x].Add(pair.Key);
}
else
{
leftover_dependency_map[x] = new List<int>() { pair.Key };
}
}
}

throw new CyclicDependencyError(leftover_dependency_map);
}

reversed_dependency_arr.Reverse();
return reversed_dependency_arr;
}

public static string pretty_print_node_path(IEnumerable<TrackableReference> paths)
{
if (paths.Count() == 0)
{
return "root object";
}
else
{
return $"root.{string.Join(".", paths.Select(x => x.Name))}";
}
}

/// <summary>
/// Returns the substring after the "/.ATTRIBUTES/" in the checkpoint key.
/// </summary>
/// <param name="key"></param>
/// <param name="prefix"></param>
/// <returns></returns>
public static string extract_local_name(string key, string? prefix = null)
{
if(prefix is null)
{
prefix = "";
}
var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix;
var index = key.IndexOf(search_key);
if (index < 0)
{
// The attributes marker is absent, so return the key unchanged.
return key;
}
return key.Substring(index + search_key.Length);
}
}
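
A minimal usage sketch of the two helpers above; the node ids and the example object path are made up for illustration.

using System;
using System.Collections.Generic;
using Tensorflow.Training;

// Node 0 depends on 1 and 2, node 1 depends on 2, node 2 depends on nothing.
var deps = new Dictionary<int, IEnumerable<int>>
{
    [0] = new[] { 1, 2 },
    [1] = new[] { 2 },
    [2] = Array.Empty<int>()
};

// Dependencies come first, so the only valid order here is 2, 1, 0.
var order = TrackableUtils.order_by_dependency(deps);
Console.WriteLine(string.Join(", ", order));                       // 2, 1, 0

// Checkpoint keys are the escaped object path plus the ".ATTRIBUTES" marker.
Console.WriteLine(TrackableUtils.checkpoint_key("root/dense", "kernel"));
// root/dense/.ATTRIBUTES/kernel
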

+ 370
- 0
src/TensorFlowNET.Core/Training/data_structures.cs View File

@@ -0,0 +1,370 @@
using Google.Protobuf;
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO.Compression;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Keras;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;

namespace Tensorflow.Training
{
public class NoDependency
{
public Trackable Value { get; set; }
public NoDependency(Trackable value)
{
Value = value;
}
}

public abstract class TrackableDataStructure : Trackable
{
private bool _self_trainable;
private List<IVariableV1> _self_extra_variables;

public TrackableDataStructure()
{
_self_trainable = true;
_self_extra_variables = new List<IVariableV1>();
}

public abstract IEnumerable<Trackable> Values { get; }
public bool Trainable { get => _self_trainable; set => _self_trainable = value; }
public IEnumerable<ILayer> Layers
{
get
{
List<ILayer> collected = new();
foreach(var obj in Values)
{
if(obj is ILayer)
{
collected.Add((ILayer)obj);
}
else if(obj is TrackableDataStructure)
{
collected.AddRange((obj as TrackableDataStructure).Layers);
}
}
return collected;
}
}
public IEnumerable<IVariableV1> TrainableWeights
{
get
{
if (!_self_trainable)
{
return new List<IVariableV1>();
}
List<IVariableV1> trainable_variables = new();
foreach (var obj in Values)
{
// skip the process of `module.Module`.
if (obj is TrackableDataStructure)
{
trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables);
}
}
foreach(var v in _self_extra_variables)
{
if (v.Trainable)
{
trainable_variables.Add(v);
}
}
return trainable_variables;
}
}
public IEnumerable<IVariableV1> NonTrainableWeights
{
get
{
var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList();
var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList();
List<IVariableV1> non_trainable_variables = new();
foreach(var obj in Values)
{
// skip the process of `module.Module`.
if (obj is TrackableDataStructure)
{
non_trainable_variables.AddRange((obj as TrackableDataStructure).NonTrainableVariables);
}
}

if (!_self_trainable)
{
// Return order is all trainable vars, then all non-trainable vars.
List<IVariableV1> trainable_variables = new();
foreach(var obj in Values)
{
// skip the process of `module.Module`.
if (obj is TrackableDataStructure)
{
trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables);
}
}
return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables);
}
else
{
return non_trainable_variables.concat(non_trainable_extra_variables);
}
}
}
public IEnumerable<IVariableV1> Weights => TrainableWeights.Concat(NonTrainableWeights);
public IEnumerable<IVariableV1> TrainableVariables => TrainableWeights;
public IEnumerable<IVariableV1> NonTrainableVariables => NonTrainableWeights;
public IEnumerable<IVariableV1> Variables => Weights;

// TODO: `losses` property.

/// <summary>
/// Add a dependency on `value`.
/// </summary>
/// <param name="value"></param>
/// <param name="name"></param>
protected virtual Trackable _track_value(Trackable value, string name)
{
value = sticky_attribute_assignment(this, name, value);
if(value is IVariableV1)
{
_self_extra_variables.Add(value as IVariableV1);
}
// skip the left process (need to be done in the future).
return value;
}

public static Trackable wrap_or_unwrap(NoDependency value)
{
return value.Value;
}

public static Trackable wrap_or_unwrap(Trackable value)
{
return value;
}

public static Trackable wrap_or_unwrap(IList<Trackable> value)
{
return new ListWrapper(value);
}

public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value)
{
return new ListWrapper(value.ToList());
}

protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value)
{
value = wrap_or_unwrap(value);
trackable._track_trackable(value, name, true);
return value;
}

protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value)
{
var wrapped_value = wrap_or_unwrap(value);
trackable._track_trackable(wrapped_value, name, true);
return wrapped_value;
}

protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value)
{
var wrapped_value = wrap_or_unwrap(value);
trackable._track_trackable(wrapped_value, name, true);
return wrapped_value;
}
}

public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable
{
private IList<Trackable> _storage;
private bool _non_append_mutation_value;
private bool _external_modification_value;
private IList<Trackable> _last_wrapped_list_snapshot;
/// <summary>
/// Wraps a list of Trackable objects, tracking its elements and detecting mutations that are not simple appends.
/// </summary>
/// <param name="wrapped_list">The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be
/// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param>
public ListWrapper(IList<Trackable> wrapped_list)
{
_storage = wrapped_list;
_non_append_mutation_value = _external_modification_value = false;
_last_wrapped_list_snapshot = new List<Trackable>(_storage);
}

protected bool NonAppendMutation {
get => _non_append_mutation_value;
set
{
// TODO: deal with `attribute_sentinel`.
_non_append_mutation_value = value;
}
}

protected bool ExternalModification
{
get => _external_modification_value;
set
{
// TODO: deal with `attribute_sentinel`.
_external_modification_value = value;
}
}

public override IEnumerable<Trackable> Values => this;
public bool IsReadOnly { get => _storage.IsReadOnly; }

/// <summary>
/// Checks for any changes to the wrapped list not through the wrapper.
/// </summary>
private void check_external_modification()
{
if (_external_modification_value || _non_append_mutation_value) return;
if (!_storage.SequenceEqual(_last_wrapped_list_snapshot))
{
_external_modification_value = true;
}
}

private void update_snapshot()
{
// TODO: deal with `attribute_sentinel`.
if (_external_modification_value || _non_append_mutation_value) return;
_last_wrapped_list_snapshot = new List<Trackable>(_storage);
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
check_external_modification();
if (_non_append_mutation_value)
{
throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" +
$", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." +
$"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.");
}
if (_external_modification_value)
{
throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified " +
$"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks " +
$"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored.");
}
var children = base._trackable_children(save_type, cache);

if(save_type == SaveType.SAVEDMODEL)
{
children = children.Concat(this.Where(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair<string, Trackable>(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value);
}

return children;
}

private bool has_mutation_or_trackable()
{
return _non_append_mutation_value;
}

/// <summary>
/// Allows storage of non-trackable objects.
/// </summary>
/// <param name="value"></param>
/// <param name="name"></param>
/// <returns></returns>
protected override Trackable _track_value(Trackable value, string name)
{
try
{
base._track_value(value, name);
}
catch(ValueError)
{
value = sticky_attribute_assignment(this, name, value);
}
return value;
}

public object Clone()
{
var res = new ListWrapper(_storage.Select(x => x).ToList());
res.NonAppendMutation = _non_append_mutation_value;
res.ExternalModification = _external_modification_value;
return res;
}

public Trackable this[int index] {
get => _storage[index];
set
{
// skip the process of `Slice`, maybe support it in the future.
_non_append_mutation_value = true;
_storage[index] = _track_value(value, _name_element(index));

update_snapshot();
}
}

public int IndexOf(Trackable item) => _storage.IndexOf(item);

public void Insert(int index, Trackable item)
{
check_external_modification();
_non_append_mutation_value = true;
_storage.Insert(index, item);
update_snapshot();
}

public void RemoveAt(int index)
{
check_external_modification();
if (has_mutation_or_trackable())
{
_non_append_mutation_value = true;
}
_storage.RemoveAt(index);
update_snapshot();
}

public int Count { get => _storage.Count; }

public void Add(Trackable item)
{
check_external_modification();
_storage.Add(item);
update_snapshot();
}

public void Clear() => _storage.Clear();

public bool Contains(Trackable item) => _storage.Contains(item);

public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex);

public bool Remove(Trackable item)
{
check_external_modification();
if (has_mutation_or_trackable())
{
_non_append_mutation_value = true;
}
var res = _storage.Remove(item);
update_snapshot();
return res;
}

public IEnumerator<Trackable> GetEnumerator() => _storage.GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator();

protected string _name_element(int index) => $"{index}";
}
}
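
The append-only contract of `ListWrapper` can be summarized with a short sketch. The `AutoTrackable` instances are placeholders and the parameterless constructor is assumed.

using System.Collections.Generic;
using Tensorflow.Train;
using Tensorflow.Training;

var a = new AutoTrackable();
var b = new AutoTrackable();

var wrapped = new ListWrapper(new List<Trackable>());
wrapped.Add(a);     // appends keep the wrapper checkpointable
wrapped.Add(b);

wrapped[0] = b;     // in-place replacement flips the non-append-mutation flag

// A later call to _trackable_children (made while saving) now raises a ValueError
// stating that only append-only mutations are supported.
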

+ 71
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -2,14 +2,20 @@
using System;
using Tensorflow.Eager;
using Tensorflow.Variables;
using Tensorflow.Train;
using static Tensorflow.Binding;
using System.Collections.Generic;
using Tensorflow.ModelSaving;
using System.Diagnostics;
using Tensorflow.Checkpoint;

namespace Tensorflow
{
public class BaseResourceVariable : DisposableObject
public class BaseResourceVariable : DisposableTrackableObject
{
protected string _name;
public virtual string Name => _handle_name;
public virtual string SharedName => _name;
protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;
protected string _handle_name;
@@ -19,9 +25,10 @@ namespace Tensorflow
public string UniqueId => _unique_id;

protected bool _in_graph_mode;
internal bool InGraphMode => _in_graph_mode;

protected bool _trainable;
public bool trainable => _trainable;
public bool Trainable => _trainable;

protected Tensor _initial_value;

@@ -46,6 +53,7 @@ namespace Tensorflow
public Graph Graph => handle.graph;
public string Device => handle.Device;
EagerResourceDeleter eager_resource_deleter;
public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None;

public BaseResourceVariable()
{
@@ -73,6 +81,11 @@ namespace Tensorflow
_handle = handle.EagerTensorHandle.DangerousGetHandle();
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
}
else if(handle is null)
{
// TODO: fix this dangerous change.
_handle = IntPtr.Zero;
}
else
{
_handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle();
@@ -165,7 +178,7 @@ namespace Tensorflow
/// </summary>
void variable_accessed(BaseResourceVariable variable)
{
if (variable.trainable)
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
tape.VariableAccessed(variable as ResourceVariable);
@@ -243,5 +256,60 @@ namespace Tensorflow
else
return value();
}

public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options)
{
BaseResourceVariable new_variable;
if (save_options.experimental_variable_policy.save_variable_devices())
{
tf.device(this.Device);
Debug.Assert(this is ResourceVariable);
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
}
else
{
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
}
Dictionary<Trackable, Trackable> obj_map = new();
Dictionary<Tensor, Tensor> resource_map = new();
obj_map[this] = new_variable;
resource_map[this.handle] = new_variable.handle;
return (obj_map, resource_map);
}

/// <summary>
/// Writes additional information of the variable into the SavedObject proto.
/// Subclasses of ResourceVariable may choose to override this method to
/// customize the extra information provided when saving a SavedModel.
/// </summary>
/// <param name="proto"></param>
/// <param name="options"></param>
public virtual void write_object_proto(SavedObject proto, SaveOptions options)
{
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
}

public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this;
return res;
}

public Tensor is_initialized(string name = null)
{
return gen_resource_variable_ops.var_is_initialized_op(this.handle, name);
}

public Tensor read_value_no_copy()
{
Tensor value = null;
tf_with(ops.name_scope("Read"), _ =>
{
// TODO: `no_copy = true`.
value = _read_variable_op();
});
return array_ops.identity(value);
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Variables/IVariableV1.cs View File

@@ -46,6 +46,7 @@ namespace Tensorflow
Graph Graph { get; }
TF_DataType dtype { get; }
Shape shape { get; }
bool Trainable { get; }
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null);


+ 3
- 1
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -20,11 +20,12 @@ using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;
using Tensorflow.Train;

namespace Tensorflow
{
[Obsolete]
public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable>
public partial class RefVariable: Trackable, IVariableV1, IProtoBuf<VariableDef, RefVariable>
{
protected string _name;
public string UniqueId => _name;
@@ -56,6 +57,7 @@ namespace Tensorflow
public string Name => _variable.name;

public Tensor eval() => _variable;
public bool Trainable => _trainable;

public RefVariable(object initial_value = null,
bool trainable = true,


+ 3
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -17,7 +17,9 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using Tensorflow.Checkpoint;
using Tensorflow.NumPy;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -39,6 +41,7 @@ namespace Tensorflow
VariableAggregation aggregation = VariableAggregation.None,
Shape shape = null)
{
Aggregation = aggregation;
if (variable_def != null)
{
if (initial_value != null)


+ 70
- 0
src/TensorFlowNET.Core/Variables/UninitializedVariable.cs View File

@@ -0,0 +1,70 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Gradients;
using static Tensorflow.Binding;

namespace Tensorflow.Variables
{
/// <summary>
/// A variable with no initializer.
/// </summary>
public sealed class UninitializedVariable: BaseResourceVariable
{
// TODO: complete the arg list.
public UninitializedVariable(
bool trainable = true,
string caching_device = "",
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
VariableAggregation aggregation = VariableAggregation.None,
Shape shape = null,
Tensor extra_handle_data = null)
{
string unique_id = "";
string handle_name = "";
tf_with(ops.init_scope(), (x) =>
{
_in_graph_mode = !tf.Context.executing_eagerly();
tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name =>
{
handle_name = ops.name_from_scope_name(name);
string? shared_name;
if (_in_graph_mode)
{
shared_name = handle_name;
unique_id = shared_name;
}
else
{
unique_id = $"{handle_name}-{ops.uid()}";
shared_name = null;
}
var handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data);
// skip the assignment of `handle._parent_trackable` because of lack of API.
// skip the assignment of `handle._name` and `handle._unique_id` because of accessibility.

if (_in_graph_mode)
{
tf_with(ops.name_scope("Read"), _ =>
{
tf.device(handle.Device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
// _maybe_set_handle_data(dtype, handle, value)
_graph_element = value;
});
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
}
else
{
_graph_element = null;
}
});
});
_shape = shape;
_dtype = dtype;
base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name);
}
}
}
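
A brief usage sketch of the variable above; the shape, dtype and name are placeholders.

using Tensorflow;
using Tensorflow.Variables;

// No initializer is run; the handle is created from shape and dtype only.
var v = new UninitializedVariable(
    trainable: false,
    name: "buffer",
    dtype: TF_DataType.TF_FLOAT,
    shape: new Shape(2, 2));

// The resource exists but holds no value yet, so this evaluates to false
// until something assigns to the handle.
var initialized = v.is_initialized();
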

+ 18
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -566,5 +566,23 @@ namespace Tensorflow
else
throw new NotImplementedException("");
}

public static bool inside_function()
{
return get_default_graph().building_function;
}

public static void dismantle_graph(Graph graph)
{
// Intentionally a no-op for now; graph teardown is not yet implemented in this port.
}

public class NullContextManager: IDisposable
{
public void Dispose()
{
}
}
}
}

+ 17
- 14
src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
public ModelConfig get_config()
public override IKerasConfig get_config()
{
return get_network_config();
}
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
{
Name = name
};
var node_conversion_map = new Dictionary<string, int>();
foreach (var layer in _self_tracked_trackables)
{
@@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine
}

var layer_configs = new List<LayerConfig>();
foreach (var layer in _self_tracked_trackables)
using (SharedObjectSavingScope.Enter())
{
var filtered_inbound_nodes = new List<NodeConfig>();
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
foreach (var layer in _self_tracked_trackables)
{
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key) && !node.is_input)
var filtered_inbound_nodes = new List<NodeConfig>();
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
{
var node_data = node.serialize(_make_node_key, node_conversion_map);
filtered_inbound_nodes.append(node_data);
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key) && !node.is_input)
{
var node_data = node.serialize(_make_node_key, node_conversion_map);
filtered_inbound_nodes.append(node_data);
}
}
}

var layer_config = generic_utils.serialize_keras_object(layer);
layer_config.Name = layer.Name;
layer_config.InboundNodes = filtered_inbound_nodes;
layer_configs.Add(layer_config);
var layer_config = generic_utils.serialize_layer_to_config(layer);
layer_config.Name = layer.Name;
layer_config.InboundNodes = filtered_inbound_nodes;
layer_configs.Add(layer_config);
}
}
config.Layers = layer_configs;



+ 50
- 0
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -2,7 +2,9 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
@@ -20,6 +22,30 @@ namespace Tensorflow.Keras.Engine

Dictionary<long, int> tensor_usage_count;

/// <summary>
/// Dictionary of layer dependencies to be included in the checkpoint.
/// </summary>
public IDictionary<string, ILayer> LayerCheckpointDependencies
{
get
{
int weight_layer_index = 0;
Dictionary<string, ILayer> dependencies = new();
for(int i = 0; i < Layers.Count; i++)
{
var layer = Layers[i];
var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList();
if(weights.Count > 0)
{
dependencies[$"layer_with_weights-{weight_layer_index}"] = layer;
weight_layer_index++;
}
dependencies[$"layer-{i}"] = layer;
}
return dependencies;
}
}

public Functional(Tensors inputs, Tensors outputs, string name = null)
: base(new ModelArgs
{
@@ -44,6 +70,7 @@ namespace Tensorflow.Keras.Engine
this.inputs = inputs;
this.outputs = outputs;
built = true;
_buildInputShape = inputs.shape;

if (outputs.Any(x => x.KerasHistory == null))
base_layer_utils.create_keras_history(outputs);
@@ -325,5 +352,28 @@ namespace Tensorflow.Keras.Engine

return output_tensors;
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache))
.ToDictionary(x => x.Key, x => x.Value);
}

protected override void _init_set_name(string name, bool zero_based = true)
{
if (string.IsNullOrEmpty(name))
{
string class_name = GetType().Name;
if (this.GetType() == typeof(Functional))
{
class_name = "Model";
}
this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based);
}
else
{
this.name = name;
}
}
}
}
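
As an example of the key layout produced by `LayerCheckpointDependencies`: for a functional model whose first layer carries no weights and whose next two layers each do, the dependencies come out, in insertion order, as `layer-0`, `layer_with_weights-0`, `layer-1`, `layer_with_weights-1`, `layer-2`, mirroring the naming used by TensorFlow's Python Keras checkpoints.
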

+ 32
- 0
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -0,0 +1,32 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Keras.Engine;

public abstract partial class Layer
{
public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);

public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;

public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
IDictionary<string, Trackable> children;
if (save_type == SaveType.SAVEDMODEL)
{
Debug.Assert(cache is not null);
children = TrackableSavedModelSaver.trackable_children(cache);
}
else
{
children = new Dictionary<string, Trackable>();
}

return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value);
}
}

+ 32
- 4
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine
public bool Built => built;
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;
public bool AutoCast => args.Autocast;
public IRegularizer ActivityRegularizer => args.ActivityRegularizer;

/// <summary>
/// A stateful layer is a layer whose updates are run during inference too,
@@ -59,6 +61,7 @@ namespace Tensorflow.Keras.Engine
/// Provides information about which inputs are compatible with the layer.
/// </summary>
protected InputSpec inputSpec;
public InputSpec InputSpec => inputSpec;
bool dynamic = true;
public bool SupportsMasking { get; set; }
protected List<IVariableV1> _trainable_weights;
@@ -77,6 +80,8 @@ namespace Tensorflow.Keras.Engine
protected bool computePreviousMask;
protected List<Operation> updates;
public Shape BatchInputShape => args.BatchInputShape;
protected TensorShapeConfig _buildInputShape = null;
public TensorShapeConfig BuildInputShape => _buildInputShape;

List<INode> inboundNodes;
public List<INode> InboundNodes => inboundNodes;
@@ -86,9 +91,29 @@ namespace Tensorflow.Keras.Engine

ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;
public Tensor[] input
{
get
{
if(inboundNodes is not null && inboundNodes.Count > 0)
{
return inboundNodes[0].input_tensors;
}
return null;
}
}
public Dictionary<int, List<INode>> NodesByDepth { get; set; }
public Shape OutputShape => inboundNodes[0].Outputs.shape;
public Shape OutputShape
{
get
{
if(inboundNodes is not null && inboundNodes.Count > 0)
{
return inboundNodes[0].Outputs.shape;
}
return null;
}
}
protected List<ILayer> _self_tracked_trackables;

public Layer(LayerArgs args)
@@ -162,7 +187,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="inputs"></param>
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <param name="training"></param>
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
@@ -201,6 +226,7 @@ namespace Tensorflow.Keras.Engine

public virtual void build(Shape input_shape)
{
_buildInputShape = input_shape;
built = true;
}

@@ -286,7 +312,9 @@ namespace Tensorflow.Keras.Engine
}
}

public virtual LayerArgs get_config()
public List<IVariableV1> Variables => weights;

public virtual IKerasConfig get_config()
=> args;
}
}

+ 5
- 0
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -33,6 +33,11 @@ namespace Tensorflow.Keras.Engine
int workers = 1,
bool use_multiprocessing = false)
{
if (x.dims[0] != y.dims[0])
{
throw new InvalidArgumentError(
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];


+ 17
- 2
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -1,5 +1,8 @@
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving;

namespace Tensorflow.Keras.Engine
@@ -18,9 +21,21 @@ namespace Tensorflow.Keras.Engine
bool overwrite = true,
bool include_optimizer = true,
string save_format = "tf",
SaveOptions options = null)
SaveOptions? options = null,
ConcreteFunction? signatures = null,
bool save_traces = true)
{
saver.save(this, filepath);
if (save_format != "pb")
{
saver.save(this, filepath);
}
else
{
using (SharedObjectSavingScope.Enter())
{
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
}
}
}
}
}
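
A usage sketch of the new branch above, assuming an already-built Keras model and the `save` signature shown in this file; the model variable and output paths are placeholders.

// Routes through KerasSavedModelUtils.Save and writes a SavedModel-style directory.
model.save("saved_models/mnist", save_format: "pb", include_optimizer: false);

// Any other value (including the default "tf") keeps the previous saver.save behavior.
model.save("checkpoints/mnist", save_format: "tf");
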

+ 19
- 0
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -4,6 +4,8 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -34,6 +36,13 @@ namespace Tensorflow.Keras.Engine
IVariableV1 _predict_counter;
bool _base_model_initialized;
bool stop_training;
DataHandler data_handler;
public OptimizerV2 Optimizer
{
get => optimizer;
set => optimizer = value;
}

public Model(ModelArgs args)
: base(args)
@@ -101,5 +110,15 @@ namespace Tensorflow.Keras.Engine
return variables;
}
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if(save_type == SaveType.SAVEDMODEL)
{
//TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`.
}
var children = base._trackable_children(save_type, cache);
return children;
}
}
}

+ 1
- 0
src/TensorFlowNET.Keras/Layers/Activation/ELU.cs View File

@@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers {
{
throw new ValueError("Alpha must be a number greater than 0.");
}
_buildInputShape = input_shape;
built = true;
}



+ 1
- 0
src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs View File

@@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers {
}
public override void build(Shape input_shape)
{
_buildInputShape = input_shape;
built = true;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 5
- 4
src/TensorFlowNET.Keras/Layers/Activation/SELU.cs View File

@@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers {
// SELU has no arguments
}
public override void build(Shape input_shape) {
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}
built = true;
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}
_buildInputShape = input_shape;
built = true;
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor output = inputs;


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Attention/Attention.cs View File

@@ -4,6 +4,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
@@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers
return scores;
}

public override LayerArgs get_config() => this.args;
public override IKerasConfig get_config() => this.args;
//var config = new Dictionary<object, object> {
// {
// "use_scale",


Some files were not shown because too many files changed in this diff
