Browse Source

Add pb model save (#976)

* Add check for dims of x and y in model.fit.

* Init the serialization of keras pb model.

* Add more facilities to the saved model framework.

* Add ListWrapper and ITrackable, and revise implementations.

* Add serialized attributes.

* Implement layer serializations.

* Add missing implementations (mainly MultiDeviceSaver).

* Support autograph.to_graph under graph mode.

* Add more implementations to the pb model save.

* Add more implementations to the keras part of pb model save.

* Refine some code after merge.

* Add two simple sequential test case of pb model save.

* Implement serializing attributes for other keras arg definitions.

* Add alexnet pb save test.

* Check and refine the code.

---------

Co-authored-by: AsakusaRinne <AsakusaRinne@gmail.com>
tags/v0.100.4-load-saved-model
Haiping GitHub 2 years ago
parent
commit
197224fd74
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 3041 additions and 169 deletions
  1. +22
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.cs
  2. +152
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
  4. +64
    -0
      src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
  5. +255
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  6. +223
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  7. +16
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs
  8. +82
    -0
      src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
  9. +195
    -0
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  10. +540
    -0
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  11. +69
    -1
      src/TensorFlowNET.Core/DisposableObject.cs
  12. +31
    -0
      src/TensorFlowNET.Core/Eager/execute.cs
  13. +14
    -0
      src/TensorFlowNET.Core/Exceptions/AssertionError.cs
  14. +85
    -1
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  15. +2
    -1
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  16. +9
    -2
      src/TensorFlowNET.Core/Functions/Function.cs
  17. +32
    -14
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  18. +7
    -4
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs
  19. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs
  20. +7
    -4
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
  21. +4
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs
  22. +4
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs
  23. +19
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs
  24. +25
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
  25. +37
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
  26. +41
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
  27. +34
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
  28. +13
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs
  29. +15
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
  30. +0
    -16
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs
  31. +0
    -16
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs
  32. +0
    -10
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs
  33. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
  34. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
  35. +17
    -14
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
  36. +0
    -6
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
  37. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
  38. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
  39. +18
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs
  40. +13
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs
  41. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
  42. +8
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs
  43. +8
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs
  44. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
  45. +12
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs
  46. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs
  47. +10
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
  48. +7
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs
  49. +0
    -8
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs
  50. +18
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs
  51. +18
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs
  52. +12
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs
  53. +5
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs
  54. +8
    -4
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs
  55. +5
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
  56. +7
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs
  57. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs
  58. +2
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
  59. +7
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  60. +12
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  61. +50
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
  62. +48
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
  63. +73
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
  64. +67
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
  65. +30
    -1
      src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
  66. +4
    -2
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  67. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
  68. +15
    -0
      src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs
  69. +7
    -2
      src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
  70. +7
    -2
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  71. +5
    -2
      src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
  72. +35
    -0
      src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs
  73. +21
    -0
      src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
  74. +43
    -1
      src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
  75. +10
    -1
      src/TensorFlowNET.Core/NumPy/Axis.cs
  76. +3
    -0
      src/TensorFlowNET.Core/Numpy/Shape.cs
  77. +10
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
  78. +9
    -1
      src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
  79. +7
    -0
      src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
  80. +7
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
  81. +6
    -1
      src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
  82. +12
    -0
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  83. +12
    -0
      src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
  84. +11
    -0
      src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
  85. +13
    -0
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  86. +5
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
  87. +7
    -1
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  88. +55
    -4
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  89. +32
    -0
      src/TensorFlowNET.Core/Operations/io_ops.cs
  90. +60
    -0
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  91. +9
    -1
      src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs
  92. +16
    -0
      src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs
  93. +1
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  94. +18
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  95. +67
    -2
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  96. +12
    -0
      src/TensorFlowNET.Core/Training/IWithTrackable.cs
  97. +9
    -0
      src/TensorFlowNET.Core/Training/LayerUtils.cs
  98. +6
    -1
      src/TensorFlowNET.Core/Training/Optimizer.cs
  99. +28
    -0
      src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs
  100. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs

+ 22
- 0
src/TensorFlowNET.Core/APIs/tf.compat.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Text;

namespace Tensorflow namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
@@ -23,6 +25,26 @@ namespace Tensorflow
public class CompatApi public class CompatApi
{ {
public CompatV1Api v1 { get; } = new CompatV1Api(); public CompatV1Api v1 { get; } = new CompatV1Api();

internal string as_text(string bytes_or_text, Encoding? encoding = null)
{
    // The input is already text, so it is returned unchanged; `encoding` only
    // matters for the byte[] overload. The original dead assignment
    // (`if (encoding is null) encoding = Encoding.UTF8;`) was removed — the
    // local was never read afterwards.
    return bytes_or_text;
}
internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
{
    // Decode the byte payload with the caller-supplied encoding, defaulting to UTF-8.
    var effectiveEncoding = encoding ?? Encoding.UTF8;
    return effectiveEncoding.GetString(bytes_or_text);
}
// Alias of as_text, kept for parity with TensorFlow's `compat.as_str`.
internal string as_str(string bytes_or_text, Encoding? encoding = null)
    => as_text(bytes_or_text, encoding);
// Alias of as_text for byte input, kept for parity with TensorFlow's `compat.as_str`.
internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
    => as_text(bytes_or_text, encoding);
} }


public bool executing_eagerly() public bool executing_eagerly()


+ 152
- 0
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -0,0 +1,152 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Helpers shared by the checkpoint serialization paths: graph traversal,
/// slot-variable gathering and object-graph-proto post-processing.
/// </summary>
public static class CheckPointUtils
{
    // Escape character used when flattening object paths into checkpoint keys.
    // Currently unreferenced in this class; kept for parity with the Python source.
    private static string _ESCAPE_CHAR = ".";

    /// <summary>
    /// Traverses the object graph rooted at <paramref name="graph_view"/> and returns, in BFS order:
    /// the trackables, their paths, their node ids, their serialized slot variables and their
    /// checkpoint-key object names.
    /// </summary>
    public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
        IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();

        // Human-readable checkpoint name for every trackable.
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }

        // Node id == position in the BFS order.
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }

        var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
        return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
    }

    /// <summary>
    /// Collects slot-variable references for every <see cref="Optimizer"/> among
    /// <paramref name="trackable_objects"/>. Non-optimizer trackables are skipped.
    /// </summary>
    public static
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
            IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
    {
        var non_slot_objects = trackable_objects.ToList();
        Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
            slot_variables = new();
        foreach (var trackable in non_slot_objects)
        {
            if (trackable is not Optimizer optim)
            {
                continue;
            }

            var slot_names = optim.get_slot_names();
            foreach (var slot_name in slot_names)
            {
                for (int original_variable_node_id = 0;
                     original_variable_node_id < non_slot_objects.Count;
                     original_variable_node_id++)
                {
                    // BUGFIX: the original assigned `slot_variable = null` for non-variable
                    // trackables and then unconditionally cast to IVariableV1, which would
                    // throw InvalidCastException. Skip non-variables instead, mirroring the
                    // `isinstance` check in the Python implementation.
                    if (non_slot_objects[original_variable_node_id] is not IVariableV1 original_variable)
                    {
                        continue;
                    }
                    var slot_variable = optim.get_slot(original_variable, slot_name);
                    if (slot_variable is null) continue;

                    // There are some problems with the inheritance of `Variable` and `Trackable`,
                    // so recording the slot reference is not implemented yet.
                    throw new NotImplementedException();
                }
            }
        }

        return slot_variables;
    }

    /// <summary>
    /// Returns the mapped substitute for <paramref name="trackable"/> from
    /// <paramref name="object_map"/>, or the trackable itself when no mapping exists.
    /// </summary>
    public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
    {
        if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
        {
            return trackable;
        }
        return possible_res;
    }

    /// <summary>
    /// Returns the full (pre-object-graph) name of a variable, or "" for non-variables.
    /// </summary>
    public static string get_full_name(Trackable variable)
    {
        // TODO: This state is not correct, the whole framework needs to be updated in the future.
        if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
        {
            return "";
        }
        // skip the check of attribute `_save_slice_info`.

        // TODO: Need to be revised!!!
        Debug.Assert(variable is BaseResourceVariable);
        return ((BaseResourceVariable)variable).Name;
    }

    /// <summary>
    /// Marks nodes that (transitively) contain checkpoint values by walking the
    /// parent links upward from every node that owns attributes or slot variables.
    /// </summary>
    public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
    {
        HashSet<int> checkpointed_trackables = new();
        Dictionary<int, HashSet<int>> parents = new();
        for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        {
            var object_proto = object_graph_proto.Nodes[i];
            // skip the process of registered saver.
            if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
                object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
            {
                checkpointed_trackables.Add(i);
            }

            // Record the reverse edges so ancestors can be marked below.
            foreach (var child_proto in object_proto.Children)
            {
                var child = child_proto.NodeId;
                if (!parents.ContainsKey(child))
                {
                    parents[child] = new HashSet<int>();
                }

                parents[child].Add(i);
            }
        }

        // BFS upward: every ancestor of a checkpointed node is also checkpointed.
        Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable());
        while (to_visit.Count > 0)
        {
            var trackable = to_visit.Dequeue();
            if (!parents.ContainsKey(trackable)) continue;
            var current_parents = parents[trackable];
            foreach (var parent in current_parents)
            {
                checkpointed_trackables.Add(parent);
                if (parents.ContainsKey(parent))
                {
                    to_visit.Enqueue(parent);
                }
            }
            parents.Remove(trackable);
        }
        // TODO: Complete it after supporting checkpoint.
        // for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        // {
        //     object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
        // }
    }
}

+ 5
- 0
src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs View File

@@ -0,0 +1,5 @@
namespace Tensorflow.Checkpoint;

/// <summary>
/// Options for checkpoint save/restore operations; presumably mirrors
/// Python's tf.train.CheckpointOptions (names match) — TODO confirm.
/// </summary>
/// <param name="experimental_io_device">Device string on which to perform checkpoint I/O; null lets the runtime decide.</param>
/// <param name="experimental_enable_async_checkpoint">When true, checkpoint writing may run asynchronously.</param>
public record class CheckpointOptions(
    string? experimental_io_device = null,
    bool experimental_enable_async_checkpoint = false);

+ 64
- 0
src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs View File

@@ -0,0 +1,64 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Serilog.Debugging;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;

/// <summary>
/// A view over a Trackable object graph that can also carry dependencies
/// attached to the root in addition to its natural children.
/// </summary>
public class ObjectGraphView : TrackableView, ICloneable
{
    // Extra dependencies reported only for the root object.
    protected IEnumerable<TrackableReference>? _attached_dependencies;

    // TODO: attached_dependencies
    public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null) : base(root)
    {
        _attached_dependencies = attached_dependencies;
    }

    /// <summary>
    /// Shallow copy. A real deep copy (corresponding to
    /// tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__) is still TODO.
    /// </summary>
    public object Clone() => new ObjectGraphView(Root, _attached_dependencies);

    /// <summary>
    /// Lists the children of <paramref name="obj"/>; when <paramref name="obj"/> is the
    /// root (reference comparison on purpose), the attached dependencies are appended.
    /// </summary>
    public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        var result = new List<TrackableReference>();
        foreach (var pair in base.children(obj, save_type, serialization_cache))
        {
            result.Add(new TrackableReference(pair.Key, pair.Value));
        }
        // Check the reference, not value.
        if (obj == Root && _attached_dependencies is not null)
        {
            result.AddRange(_attached_dependencies);
        }
        return result;
    }

    /// <summary>
    /// Dictionary form of <see cref="list_children"/>, keyed by local name.
    /// </summary>
    public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        var child_list = list_children(obj, save_type, serialization_cache);
        return child_list.ToDictionary(x => x.Name, x => x.Refer);
    }

    public IEnumerable<TrackableReference>? AttachedDependencies => _attached_dependencies;

    /// <summary>
    /// BFS over the graph; returns the trackables in visit order plus their paths.
    /// </summary>
    public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
        => base._descendants_with_paths();

    // TODO: complete the implementation
    public void serialize_object_graph(object? saveables_cache = null)
        => throw new NotImplementedException();

    // TODO: complete the implementation
    public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
        => throw new NotImplementedException();
}

+ 255
- 0
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -0,0 +1,255 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// Bundles everything needed to serialize one trackable into the object graph proto.
/// </summary>
internal record class TrackableData(
    // A trackable in the root Trackable object graph.
    Trackable trackable,
    // The index at which the Trackable appears in TrackableObjectGraph.nodes.
    int node_id,
    // The BFS-generated path from the root object / used to generate readable checkpoint keys.
    string object_name,
    // A list of ObjectReference for each child connected to this Trackable.
    pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
    // A list of SlotVariableReference to save to the object (only valid for Optimizer objects).
    pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
    // The object to save to checkpoint. Usually this is the same as `trackable`,
    // but can differ when the caller wants to specify a different object to
    // save. For example, when saving checkpoints asynchronously, variables are
    // copied to the CPU. `object_to_save` is set as the copied variable.
    Trackable object_to_save
);
/// <summary>
/// Serializes a Trackable object graph into the tensors and protos required to
/// write a TF2-style checkpoint.
/// </summary>
public static class SaveUtil
{
    /// <summary>
    /// Gathers and serializes the object graph rooted at <paramref name="graph_view"/>.
    /// </summary>
    /// <returns>
    /// (serialized tensors per trackable, feed additions, registered savers, object graph proto).
    /// </returns>
    public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
        serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
    {
        var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
        var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

        var object_graph_proto = fill_object_graph_proto(trackable_data);

        var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
        var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

        Dictionary<Tensor, object> feed_additions;
        if (cache is null)
        {
            feed_additions = null;
            serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
                cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
        }
        else
        {
            feed_additions = null;
            // TODO: deal with cache.
            // BUGFIX: the original threw NotFiniteNumberException here, an obvious
            // typo for NotImplementedException (the cache path is simply unimplemented).
            throw new NotImplementedException();
        }

        CheckPointUtils.add_checkpoint_values_check(object_graph_proto);

        return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
    }

    /// <summary>
    /// Runs a BFS over the graph and packages each trackable with its node id,
    /// object name, children references and slot variables.
    /// </summary>
    private static (IList<TrackableData>, IDictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }
        // Node id == position in the BFS order.
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }
        var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
        List<TrackableData> trackable_data = new();
        foreach (var trackable in trackable_objects)
        {
            pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto = new();
            foreach (var child in graph_view.list_children(trackable))
            {
                children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
                {
                    NodeId = node_ids[child.Refer],
                    LocalName = child.Name
                });
            }
            slot_variables.TryGetValue(trackable, out var slot_variable);
            trackable_data.Add(new TrackableData(
                trackable: trackable,
                node_id: node_ids[trackable],
                object_name: object_names[trackable],
                children_proto: children_proto,
                slot_variable_proto: slot_variable ?? new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>(),
                object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map)
            ));
        }
        return (trackable_data, node_ids);
    }

    /// <summary>
    /// Builds the TrackableObjectGraph skeleton (children and slot variables;
    /// tensor attributes are added later).
    /// </summary>
    private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
    {
        TrackableObjectGraph object_graph_proto = new();
        for (int i = 0; i < trackable_data.Count; i++)
        {
            var td = trackable_data[i];
            Debug.Assert(td.node_id == i);
            object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
        }
        return object_graph_proto;
    }

    /// <summary>
    /// Creates the dictionary of tensors to checkpoint, and updates the proto.
    /// </summary>
    /// <param name="tensor_trackables"></param>
    /// <param name="node_ids"></param>
    /// <param name="call_with_mapped_captures"></param>
    /// <param name="cache"></param>
    /// <param name="object_graph_proto"></param>
    /// <returns>Map from each trackable to its named tensors to save.</returns>
    private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
        bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
    {
        Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
        foreach (var td in tensor_trackables)
        {
            // TODO: deal with cache.
            var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
            Trackable trackable = null;
            IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
            // Objects without a modern serialize_to_tensors (or with a legacy saveable
            // name) go through the legacy SaveableObject conversion path.
            if (!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
            {
                (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
            }
            else
            {
                tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto);
                trackable = td.object_to_save;
            }
            if (trackable is not null)
            {
                serialized_tensors[trackable] = tensor_dict;
            }
            else
            {
                serialized_tensors[Trackable.None] = tensor_dict;
            }
        }
        return serialized_tensors;
    }

    /// <summary>
    /// Asks the trackable for its tensors and records the corresponding
    /// SerializedTensor attributes in the proto.
    /// </summary>
    private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
    {
        var trackable = trackable_data.object_to_save;

        // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
        IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
        if (call_with_mapped_captures)
        {
            throw new NotImplementedException();
        }
        else
        {
            ret_tensor_dict = trackable.serialize_to_tensors();
        }

        // TODO: deal with the type `SaveSpec` (currently it will never be it).
        Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
        foreach (var pair in ret_tensor_dict)
        {
            var local_name = TrackableUtils.escape_local_name(pair.Key);
            var maybe_tensor = pair.Value;
            var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name);

            tensor_dict[checkpoint_key] = maybe_tensor;

            if (maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
            {
                throw new NotImplementedException();
                //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
            }

            if (object_graph_proto is not null)
            {
                object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
                {
                    Name = local_name,
                    CheckpointKey = checkpoint_key,
                    FullName = CheckPointUtils.get_full_name(trackable)
                });
            }
        }
        return tensor_dict;
    }

    /// <summary>
    /// Gets tensors to serialize from a Trackable with legacy SaveableObjects.
    /// </summary>
    /// <param name="trackable_data"></param>
    /// <param name="node_ids"></param>
    /// <param name="call_with_mapped_captures"></param>
    /// <param name="object_graph_proto"></param>
    /// <returns>The wrapping converter trackable and its tensors.</returns>
    private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
        bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
    {
        Dictionary<Trackable, string> object_names = new();
        object_names[trackable_data.trackable] = trackable_data.object_name;
        Dictionary<Trackable, Trackable> object_map = new();
        object_map[trackable_data.trackable] = trackable_data.object_to_save;

        var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map);
        var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map,
            call_with_mapped_captures, saveables_cache: null);
        var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects);
        return (trackable, trackable.serialize_to_tensors());
    }

    /// <summary>
    /// Records registered savers in the returned map (and, eventually, the proto).
    /// </summary>
    private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
    {
        Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
        foreach (var pair in registered_trackables)
        {
            foreach (var td in pair.Value)
            {
                // BUGFIX: the original condition was inverted — it created the inner
                // dictionary only when the key already existed and otherwise indexed a
                // missing key (KeyNotFoundException), and it never recorded the first
                // entry. Create the inner dictionary on first sight, then always record.
                if (!registered_savers.ContainsKey(pair.Key))
                {
                    registered_savers[pair.Key] = new Dictionary<string, Trackable>();
                }
                registered_savers[pair.Key][td.object_name] = td.object_to_save;

                var object_proto = object_graph_proto.Nodes[td.node_id];
                // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
            }
        }
        return registered_savers;
    }

    /// <summary>
    /// Partitions the trackables into tensor-based, PyState-like and registered groups.
    /// </summary>
    private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
    {
        List<TrackableData> tensor_trackables = new();
        List<TrackableData> py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a placeholder.
        Dictionary<string, IList<TrackableData>> registered_trackables = new();

        foreach (var td in trackable_data)
        {
            // TODO: deal with registration.
            tensor_trackables.Add(td);
        }
        return (tensor_trackables, py_state_trackables, registered_trackables);
    }
}
}

+ 223
- 0
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -0,0 +1,223 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Google.Protobuf;

namespace Tensorflow.Checkpoint;

public static class SaveUtilV1
{
/// <summary>
/// Builds, for each named trackable, the list of saveable-factory entries with their
/// checkpoint keys. No registered saver is returned: per the TF saved_model
/// registration README, only internal registrations are currently allowed, so the
/// second tuple element is always null until upstream changes.
/// </summary>
public static (IDictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
    IDictionary<Trackable, Trackable>? object_map = null)
{
    Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new();
    foreach (var entry in object_names)
    {
        // skip the registration process.
        var object_to_save = CheckPointUtils.get_mapped_trackable(entry.Key, object_map);

        List<CheckpointFactoryData> factories = new();
        foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save))
        {
            // The factory's name acts as the key suffix of the checkpoint key.
            var checkpoint_key = TrackableUtils.checkpoint_key(entry.Value, name_and_factory.Key);
            factories.Add(new CheckpointFactoryData(name_and_factory.Value, name_and_factory.Key, checkpoint_key));
        }

        checkpoint_factory_map[entry.Key] = factories;
    }

    return (checkpoint_factory_map, null);
}

/// <summary>
/// Gathers and freezes the saveable objects for the graph view, appending the
/// serialized object-graph proto itself as an extra NoRestoreSaveable.
/// </summary>
/// <param name="graph_view">View over the object graph to save.</param>
/// <param name="object_map">Optional substitution map applied while gathering.</param>
/// <param name="to_graph">When non-null, work happens inside this graph's context.</param>
/// <param name="call_with_mapped_captures">Forwarded to serialization.</param>
/// <param name="saveables_cache">Optional cache forwarded to serialization.</param>
public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
    IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
    object? saveables_cache = null)
{
    if (to_graph is not null)
    {
        var g = to_graph.as_default();
        var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
            object_map, call_with_mapped_captures, saveables_cache);
        tf.device("/cpu:0");
        var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
        named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
        g.Exit();
        return (named_saveable_objects, registered_savers);
    }
    else
    {
        using (new ops.NullContextManager())
        {
            var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                object_map, call_with_mapped_captures, saveables_cache);
            tf.device("/cpu:0");
            // BUGFIX: serialize the proto to its binary wire format (`ToByteArray`),
            // matching the graph-mode branch above. The original used `ToString()`,
            // which emits the protobuf text format that checkpoint readers expecting
            // serialized bytes cannot parse.
            var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray(), TF_DataType.TF_STRING);
            named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
            return (named_saveable_objects, registered_savers);
        }
    }
}

/// <summary>
/// Traverses the object graph, serializes slot variables, builds the object-graph
/// proto and attaches per-node attributes. Returns the gathered saveables, the
/// proto, feed additions and registered savers.
/// </summary>
public static (IList<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
    IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
    var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();

    // Checkpoint name for each trackable, derived from its BFS path.
    Dictionary<Trackable, string> object_names = new();
    foreach (var entry in node_paths)
    {
        object_names[entry.Key] = TrackableUtils.object_path_to_string(entry.Value);
    }

    // Node id == position in the BFS order.
    Dictionary<Trackable, int> node_ids = new();
    int index = 0;
    foreach (var trackable in trackable_objects)
    {
        node_ids[trackable] = index;
        index++;
    }

    var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
    var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables);
    var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph(
        trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures,
        saveables_cache);
    CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
    return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers);
}

/// <summary>
/// Builds a `TrackableObjectGraph` proto with one node per trackable (in traversal order)
/// and a child reference for each of its trackable children.
/// </summary>
private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
    IDictionary<Trackable, int> node_ids,
    IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
    slot_variables)
{
    TrackableObjectGraph object_graph_proto = new();
    int position = 0;
    foreach (var trackable in trackable_objects)
    {
        // Traversal order and assigned node ids must agree.
        Debug.Assert(node_ids[trackable] == position);
        // Attach the slot-variable references when this trackable owns any.
        var object_proto = slot_variables.TryGetValue(trackable, out var slots)
            ? new TrackableObjectGraph.Types.TrackableObject(slots)
            : new TrackableObjectGraph.Types.TrackableObject();
        object_graph_proto.Nodes.Add(object_proto);
        foreach (var child in graph_view.list_children(trackable))
        {
            object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
                { NodeId = node_ids[child.Refer], LocalName = child.Name });
        }
        position++;
    }

    return object_graph_proto;
}

/// <summary>
/// Creates the saveable objects for every trackable and fills the attribute entries
/// of the matching nodes in <paramref name="object_graph_proto"/>.
/// </summary>
private static (IList<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(
    IList<Trackable> trackable_objects,
    TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
    IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
    bool call_with_mapped_captures, object? saveables_cache = null)
{
    // Sanity check: for every node present in the proto, traversal order equals node id.
    int checked_nodes = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count);
    for (int idx = 0; idx < checked_nodes; idx++)
        Debug.Assert(node_ids[trackable_objects[idx]] == idx);

    var (checkpoint_factory_map, unmapped_registered_savers) =
        get_checkpoint_factories_and_keys(object_names, object_map);
    // Registered savers are not processed yet.

    var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map,
        object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache);
    // The registered-savers slot of the tuple is always null until they are implemented.
    return (named_saveable_objects, feed_additions, null);
}

/// <summary>
/// Turns the checkpoint factory map into concrete saveable objects and, when both
/// <paramref name="object_graph_proto"/> and <paramref name="node_ids"/> are supplied,
/// records a SerializedTensor attribute on each trackable's proto node.
/// </summary>
/// <param name="checkpoint_factory_map">trackable -> factory data (a variable or a prebuilt saveable, plus names/keys).</param>
/// <param name="object_graph_proto">Proto to annotate; may be null when only the saveables are wanted.</param>
/// <param name="node_ids">trackable -> node index in the proto; must accompany the proto.</param>
/// <param name="object_map">Redirection map: save a substitute object instead of the original.</param>
/// <param name="call_with_mapped_captures">Kept to mirror the python signature; not used here.</param>
/// <param name="saveables_cache">Caching is currently skipped.</param>
/// <returns>The saveables and a feed-additions placeholder (always null here).</returns>
/// <exception cref="AssertionError">A produced saveable's name does not contain its checkpoint key.</exception>
public static (IList<MySaveableObject>, object?) generate_saveable_objects(
    IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
    TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
    IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
    List<MySaveableObject> named_saveable_objects = new();
    foreach (var pair in checkpoint_factory_map)
    {
        var trackable = pair.Key;
        var factory_data_list = pair.Value;
        // Only annotate the proto when both the proto and the id mapping were given.
        bool fill_object_proto = object_graph_proto is not null && node_ids is not null;
        TrackableObjectGraph.Types.TrackableObject object_proto = null!;
        if (fill_object_proto)
        {
            object_proto = object_graph_proto.Nodes[node_ids[trackable]];
        }

        // Respect object_map redirections (e.g. save a copy instead of the original object).
        var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
        // skip cache

        foreach (var factory_data in factory_data_list)
        {
            var name = factory_data.name;
            var key = factory_data.checkpoint_key;
            var maybe_saveable = factory_data.factory;

            // TODO: the original python code additionally handles a callable `saveable_factory` here.
            List<MySaveableObject> saveables = new();
            if (maybe_saveable.TryGet<MySaveableObject>(out var s))
            {
                // Factory already is a saveable: use it directly.
                saveables.Add(s);
            }
            else
            {
                // Factory holds a resource variable: build its saveables under `key`.
                saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key));
            }

            // Every saveable must embed its checkpoint key in its name, otherwise restore
            // could not route the tensor back.
            foreach (var saveable in saveables)
            {
                if (!saveable.name.Contains(key))
                {
                    throw new AssertionError($"The object {trackable} produced a SaveableObject with name " +
                        $"'{saveable.name}' for attribute '{name}'. Expected a name" +
                        $" containing '{key}'.");
                }
            }
            // skip the process of PythonState
            named_saveable_objects.AddRange(saveables);
            if(!fill_object_proto) continue;

            // skip the process of `TrackableSaveable` because of lack of APIs.

            object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
                { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
        }
    }

    return (named_saveable_objects, null);
}
}

/// <summary>
/// Bundles what is needed to checkpoint one attribute of a trackable.
/// </summary>
/// <param name="factory">Either the resource variable to save or an already-built saveable object.</param>
/// <param name="name">Local attribute name within the owning trackable.</param>
/// <param name="checkpoint_key">Full key under which the value is stored in the checkpoint.</param>
public record class CheckpointFactoryData
(
    Maybe<BaseResourceVariable, MySaveableObject> factory,
    string name,
    string checkpoint_key
);

+ 16
- 0
src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint
{
    /// <summary>
    /// Compatibility helpers for the legacy saveable-object checkpoint format.
    /// </summary>
    internal static class SaveableCompat
    {
        /// <summary>
        /// Returns the custom saveable name declared for <paramref name="cls_or_obj"/>,
        /// or <c>null</c> when none is declared. Currently always returns <c>null</c>
        /// because attribute-based lookup is not implemented yet.
        /// </summary>
        public static string? get_saveable_name(Trackable cls_or_obj)
        {
            // TODO: implement it with Attribute.
            return null;
        }
    }
}

+ 82
- 0
src/TensorFlowNET.Core/Checkpoint/TrackableView.cs View File

@@ -0,0 +1,82 @@
using System;
using Tensorflow.Train;
using System.Collections.Generic;
using System.IO;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Read-only view over a trackable object graph rooted at a single object.
/// </summary>
public class TrackableView
{
    // Weak reference so that holding a view does not keep the root alive.
    protected WeakReference<Trackable> _root_ref;

    public TrackableView(Trackable obj)
    {
        _root_ref = new WeakReference<Trackable>(obj);
    }

    public TrackableView(WeakReference<Trackable> obj)
    {
        _root_ref = obj;
    }

    /// <summary>
    /// Collects the trackable children of <paramref name="obj"/> for the given save type.
    /// </summary>
    public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
    {
        obj._maybe_initialize_trackable();
        // Note: in python the return type of `Trackable._trackable_children` is not fixed.
        // Therefore it uses `convert_to_trackable` to have an extra process.
        var result = new Dictionary<string, Trackable>();
        foreach (var entry in obj._trackable_children(save_type, cache))
            result[entry.Key] = entry.Value;
        return result;
    }

    /// <summary>
    /// The root trackable; throws when the weak reference is no longer alive.
    /// </summary>
    public Trackable Root
    {
        get
        {
            if (!_root_ref.TryGetTarget(out Trackable res))
            {
                throw new InvalidDataException(
                    "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor.");
            }
            return res;
        }
    }

    /// <summary>
    /// Returns a list of all nodes and its paths from self.root using a breadth first traversal.
    /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths
    /// </summary>
    protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
    {
        var visit_order = new List<Trackable>();
        var pending = new Queue<Trackable>();
        var node_paths = new Dictionary<Trackable, IEnumerable<TrackableReference>>();
        pending.Enqueue(Root);
        node_paths[Root] = new List<TrackableReference>();
        while (pending.Count > 0)
        {
            var current = pending.Dequeue();
            visit_order.Add(current);
            foreach (var entry in children(current))
            {
                var child = entry.Value;
                // Only record the first (shortest) path found to each node.
                if (node_paths.ContainsKey(child))
                    continue;
                var path = new List<TrackableReference>(node_paths[current])
                {
                    new TrackableReference(entry.Key, child)
                };
                node_paths[child] = path;
                pending.Enqueue(child);
            }
        }

        return (visit_order, node_paths);
    }
}

+ 195
- 0
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -0,0 +1,195 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Saves and restores a `Trackable` object and its dependencies.
/// </summary>
public class TrackableSaver
{
    // View over the trackable object graph being saved/restored.
    private ObjectGraphView _graph_view;
    // Save op built the last time; reused while the object graph is unchanged.
    private Tensor _cached_save_operation;
    // Object-graph proto of the previous save; compared (by reference) to decide staleness.
    private TrackableObjectGraph _last_save_object_graph;
    // Feed tensors for graph (session) mode, so proto bytes / prefix can be fed at run time.
    private Tensor? _object_graph_feed_tensor = null;
    private Tensor? _file_prefix_feed_tensor = null;
    // Optional substitution map applied during serialization (original -> object to save).
    private Dictionary<Trackable, Trackable>? _object_map = null;
    private object? _cache = null;
    public TrackableSaver(ObjectGraphView graph_view)
    {
        _graph_view = graph_view;
        // TODO: cache when not executing eagerly.
        // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
        // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
    }

    // Serializes the graph view to tensors plus the object-graph proto.
    // When `object_graph_tensor` is null, a constant holding the serialized proto is created
    // (pinned to cpu:0); otherwise the proto bytes are added to the feed additions so the
    // existing placeholder-like tensor can be fed instead.
    private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
        gather_serialized_tensors(Tensor? object_graph_tensor = null)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);

        // TODO: cache.

        if(object_graph_tensor is null)
        {
            tf.device("/cpu:0");
            object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
        }
        else
        {
            feed_additions[object_graph_tensor] = graph_proto.ToByteArray();
        }
        // The proto key must not already be present under the sentinel `Trackable.None` entry.
        Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
        if (!serialized_tensors.ContainsKey(Trackable.None))
        {
            serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
        }
        // Store the serialized object graph itself under the reserved key.
        serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
        return (serialized_tensors, feed_additions, registered_savers, graph_proto);
    }

    // Builds (or reuses) the save op for a Tensor file prefix.
    // NOTE(review): `_last_save_object_graph != graph_proto` is a reference comparison,
    // mirroring python's `is not` semantics — confirm this is intended for protobuf messages.
    private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

        Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
        {
            // Rebuild the save op when the object graph changed or we cannot reuse graph state.
            if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
            {
                var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
                var save_op = saver.save(file_prefix, options);

                // tensorflow python: `with ops.device("/cpu:0"):`
                using (ops.control_dependencies(new object[] { save_op }))
                {
                    _cached_save_operation = array_ops.identity(file_prefix);
                }
                _last_save_object_graph = graph_proto;
            }
            return (_cached_save_operation, feed_additions);
        };

        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException();
        }

        return run_save();
    }

    // Overload taking the file prefix as a string; otherwise identical to the Tensor overload.
    private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

        Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
        {
            if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
            {
                var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
                var save_op = saver.save(file_prefix, options);

                // tensorflow python: `with ops.device("/cpu:0"):`
                using (ops.control_dependencies(new object[] {save_op} ))
                {
                    _cached_save_operation = array_ops.identity(tf.constant(file_prefix));
                }
                _last_save_object_graph = graph_proto;
            }
            return (_cached_save_operation, feed_additions);
        };

        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException();
        }

        return run_save();
    }

    /// <summary>
    /// Saves the tracked objects to a checkpoint with the given prefix and returns the
    /// save-path tensor. In graph mode the save is executed through a session with the
    /// prefix and proto supplied via feeds; in eager mode the op result is returned directly.
    /// </summary>
    // TODO: parameter write_done_callback
    public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
        CheckpointOptions? options = null)
    {
        if (options is null)
        {
            options = new CheckpointOptions();
        }

        Dictionary<Tensor, object> feed_dict = new();
        // A session is only needed in graph mode outside of tf.function.
        bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
        if (checkpoint_number is not null)
        {
            file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
        }

        Tensor file_prefix_tensor;
        Tensor object_graph_tensor;
        string file_prefix_to_save;
        if (use_session)
        {
            // Lazily create feedable constants once; the actual values are fed at run time.
            if (_object_graph_feed_tensor is null)
            {
                // In python there is `with ops.device("/cpu:0")`.
                _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
                _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
            }

            object_graph_tensor = _object_graph_feed_tensor;
            file_prefix_tensor = _file_prefix_feed_tensor;
            feed_dict[file_prefix_tensor] = file_prefix;
            // The graph is built with an empty prefix; the real one arrives via the feed.
            file_prefix_to_save = "";
        }
        else
        {
            // In python there is `with ops.device("/cpu:0")`.
            file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
            // Null lets gather_serialized_tensors create the proto constant itself.
            object_graph_tensor = null;
            file_prefix_to_save = file_prefix;
        }

        var (save_path, new_feed_additions) =
            save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options);

        if (new_feed_additions is not null)
        {
            foreach (var pair in new_feed_additions)
            {
                feed_dict.Add(pair.Key, pair.Value);
            }
        }
        if(!use_session)
        {
            session = null;
        }
        else if (session is null)
        {
            session = new Session(); // In python it uses `get_session`.
        }

        if (session is not null)
        {
            var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
            return session.run((Tensor)save_path, s);
        }
        else if (use_session)
        {
            throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " +
                "in graph mode without a default session. Please use " +
                "`with tf.Session():` to create a session.");
        }
        else
        {
            return save_path;
        }
    }
}

+ 540
- 0
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -0,0 +1,540 @@
using System;
using System.Buffers.Text;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.OptimizerOptions.Types;
using static Tensorflow.Binding;
using System.Text.RegularExpressions;
using System.Linq;
using Tensorflow.Operations;
using Tensorflow.Training;
using Tensorflow.Graphs;
using System.Xml.Linq;
using System.Diagnostics;
using RestoreFunc = System.Func<object, object>;

namespace Tensorflow.Checkpoint
{
/// <summary>
/// A two-way discriminated union holding either a <typeparamref name="TA"/> or a
/// <typeparamref name="TB"/> value. Used to model python signatures that accept
/// several alternative types.
/// </summary>
public class Maybe<TA, TB>
{
    private readonly TA? _valueA = default(TA);
    private readonly TB? _valueB = default(TB);
    private readonly Type _type;
    // True when the TA slot holds the assigned value, false when the TB slot does.
    private readonly bool _assignedTA;
    public Maybe(TA value)
    {
        _valueA = value;
        _type= typeof(TA);
        _assignedTA = true;
    }
    public Maybe(TB value)
    {
        _valueB = value;
        _type = typeof(TB);
        _assignedTA = false;
    }

    /// <summary>
    /// The compile-time type (TA or TB) of the slot that was assigned.
    /// </summary>
    public Type DataType => _type;

    /// <summary>
    /// Try to get the type T member of this instance. It returns true when exactly one of
    /// TA/TB derives from T and that slot is the assigned one. It returns false (with
    /// <paramref name="res"/> left at default) when T matches neither slot, or when both
    /// TA and TB derive from T so the request is ambiguous (mirroring
    /// <see cref="GetValue{T}"/>, which throws in the ambiguous case).
    /// </summary>
    /// <typeparam name="T"></typeparam>
    /// <param name="res"></param>
    /// <returns></returns>
    public bool TryGet<T>(out T? res)
    {
        if(_valueA is T && _valueB is not T)
        {
            res = (T)(object)_valueA;
            return _assignedTA;
        }
        else if(_valueA is not T && _valueB is T)
        {
            res = (T)(object)_valueB;
            return !_assignedTA;
        }
        res = default(T);
        return false;
    }

    /// <summary>
    /// Returns true when the assigned value can be viewed as a T.
    /// </summary>
    public bool IsTypeOrDeriveFrom<T>()
    {
        if (_valueA is T && _valueB is not T)
        {
            return _assignedTA;
        }
        else if (_valueA is not T && _valueB is T)
        {
            return !_assignedTA;
        }
        else if (_valueA is T && _valueB is T)
        {
            return true;
        }
        else
        {
            return false;
        }
    }

    /// <summary>
    /// Returns the stored value as T.
    /// </summary>
    /// <exception cref="TypeError">
    /// Thrown when both TA and TB derive from T (ambiguous), or when T matches neither slot.
    /// </exception>
    public T GetValue<T>()
    {
        if (_valueA is T && _valueB is not T)
        {
            return (T)(object)_valueA;
        }
        else if (_valueA is not T && _valueB is T)
        {
            return (T)(object)_valueB;
        }
        else if (_valueA is T && _valueB is T)
        {
            throw new TypeError("The type is vague, this is always because TA and TB both derive from T.");
        }
        else
        {
            // BUGFIX: the message previously interpolated as e.g. "got typeofSystem.Int32".
            throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got {typeof(T)}.");
        }
    }

    public static implicit operator Maybe<TA, TB>(TA a)
    {
        return new Maybe<TA, TB>(a);
    }
    public static implicit operator Maybe<TA, TB>(TB b)
    {
        return new Maybe<TA, TB>(b);
    }
}
/// <summary>
/// Saves and restores checkpoint tensors that all live on a single device.
/// Internal layout: checkpoint_key -> (slice_spec -> Tensor or SaveSpec).
/// </summary>
internal class SingleDeviceSaver
{
    private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict;
    public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict)
    {
        _tensor_slice_dict = tensor_slice_dict;
    }
    public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict)
    {
        // Wrap plain tensors into the Maybe union used internally.
        _tensor_slice_dict = tensor_slice_dict.ToDictionary(
            x => x.Key, x => x.Value.ToDictionary(
                y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
            as IDictionary<string, Maybe<Tensor, SaveSpec>>);
    }
    public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict)
    {
        // Wrap save specs into the Maybe union used internally.
        _tensor_slice_dict = tensor_slice_dict.ToDictionary(
            x => x.Key, x => x.Value.ToDictionary(
                y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value))
            as IDictionary<string, Maybe<Tensor, SaveSpec>>);
    }
    /// <summary>
    /// Emits a single SaveV2 op writing all held tensors to <paramref name="file_prefix"/>.
    /// SaveSpec entries whose tensor is null are silently skipped.
    /// </summary>
    public Operation? save(Tensor file_prefix, CheckpointOptions? options = null)
    {
        if(options is null)
        {
            options = new CheckpointOptions();
        }
        List<string> tensor_names = new();
        List<Tensor> tensors = new();
        List<string> slice_specs = new();
        foreach(var pair in _tensor_slice_dict)
        {
            var checkpoint_key = pair.Key;
            var tensor_slices = pair.Value;
            foreach(var slice in tensor_slices)
            {
                var slice_spec = slice.Key;
                var maybe_tensor = slice.Value;
                if(maybe_tensor.TryGet<SaveSpec>(out var spec))
                {
                    // A SaveSpec carries its own name and slice spec.
                    var tensor_value = spec.tensor;
                    if (tensor_value is not null)
                    {
                        tensor_names.Add(spec.name);
                        tensors.Add(tensor_value);
                        slice_specs.Add(spec.slice_spec);
                    }
                }
                else
                {
                    // A bare tensor is saved under its checkpoint key.
                    var tensor = maybe_tensor.GetValue<Tensor>();
                    tensor_names.Add(checkpoint_key);
                    tensors.Add(tensor);
                    slice_specs.Add(slice_spec);
                }
            }
        }
        // TODO: specify the device.
        return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray());
    }

    public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options);

    /// <summary>
    /// Emits a RestoreV2 op reading back all held tensors from <paramref name="file_prefix"/>
    /// and returns them as checkpoint_key -> (slice_spec -> restored tensor).
    /// </summary>
    public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null)
    {
        if(options is null)
        {
            options = new CheckpointOptions();
        }
        List<string> tensor_names = new();
        List<TF_DataType> tensor_dtypes = new();
        List<string> slice_specs = new();

        foreach(var pair in _tensor_slice_dict)
        {
            var checkpoint_key = pair.Key;
            var tensor_slices = pair.Value;
            foreach(var slice in tensor_slices)
            {
                var slice_spec = slice.Key;
                var maybe_tensor = slice.Value;
                // TODO: deal with other types. Currently only `SaveSpec` is allowed.
                if(maybe_tensor.TryGet<SaveSpec>(out var spec))
                {
                    tensor_dtypes.Add(spec.dtype);
                    slice_specs.Add(spec.slice_spec);
                    tensor_names.Add(spec.name);
                }
                else
                {
                    var tensor = maybe_tensor.GetValue<Tensor>();
                    tensor_dtypes.Add(tensor.dtype);
                    slice_specs.Add(slice_spec);
                    tensor_names.Add(checkpoint_key);
                }
            }
        }

        string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

        // tf python has code `with ops.device(restore_device):` here.
        tf.device(restore_device); // may be risky.
        var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());

        // RestoreV2 returns tensors in request order; walk the dict in the same order to re-key them.
        Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
        int idx = 0;
        foreach(var pair in _tensor_slice_dict)
        {
            var checkpoint_key = pair.Key;
            var tensor_slices = pair.Value;
            foreach(var slice_spec in tensor_slices.Keys)
            {
                var restored_tensor = restored_tensors[idx++];
                if (!restored_tensor_dict.ContainsKey(checkpoint_key))
                {
                    restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>();
                }
                restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor;
            }
        }
        return restored_tensor_dict;
    }

    // BUGFIX: this overload previously dropped `options` instead of forwarding it.
    public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix), options);
}
/// <summary>
/// Saves checkpoints directly from multiple devices.
/// Note that this is a low-level utility which stores Tensors in the keys
/// specified by `SaveableObject`s. Higher-level utilities for object-based
/// checkpointing are built on top of it.
/// </summary>
public class MultiDeviceSaver
{
    // One SingleDeviceSaver per device name.
    private Dictionary<string, SingleDeviceSaver> _single_device_savers;
    // saver_name -> (save_fn, restore_fn); always empty until registered savers are implemented.
    private IDictionary<string, (RestoreFunc, RestoreFunc)> _registered_savers;
    // (checkpoint_key, slice_spec) -> restore callback of the owning trackable.
    private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn;
    // Inverse map: restore callback -> all keys it is responsible for.
    private Dictionary<RestoreFunc, IList<(string, string)>> _restore_fn_to_keys;
    /// <summary>
    /// Indexes the serialized tensors by device so each device gets its own single-device saver,
    /// and records which restore callback owns each (checkpoint_key, slice_spec) pair.
    /// </summary>
    /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
    /// <param name="registered_savers">Not supported yet; a non-empty value throws.</param>
    /// <param name="call_with_mapped_capture">Currently unused.</param>
    public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
        IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
    {
        _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>();
        _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>();
        // device name -> checkpoint_key -> slice_spec -> tensor.
        Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
        foreach(var pair in serialized_tensors)
        {
            var obj = pair.Key;
            var tensor_dict = pair.Value;
            RestoreFunc restore_fn;
            if(obj == Trackable.None)
            {
                // The sentinel entry (e.g. the object-graph proto) has nothing to restore into.
                restore_fn = new RestoreFunc(x => null);
            }
            else
            {
                restore_fn = new RestoreFunc(x =>
                {
                    if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>)
                    {
                        return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>);
                    }
                    throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}.");
                });
            }

            foreach(var item in tensor_dict)
            {
                var checkpoint_key = item.Key;
                // Normalize: a bare tensor becomes a single unsliced ("" spec) entry.
                IDictionary<string, Tensor> spec_to_tensor;
                if(item.Value.TryGet<Tensor>(out var t))
                {
                    spec_to_tensor = new Dictionary<string, Tensor>();
                    spec_to_tensor[""] = t;
                }
                else
                {
                    spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>();
                }

                foreach(var spec in spec_to_tensor)
                {
                    var slice_spec = spec.Key;
                    var tensor = spec.Value;
                    if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec)))
                    {
                        throw new ValueError("Recieved multiple tensors with the same checkpoint key and " +
                            $"slice spec. This is invalid because one will overwrite the " +
                            $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation.");
                    }
                    _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
                    _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

                    // skip the process of device name because lack of API.
                    var host_device = tensor.Device;
                    var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
                    if (!internal_dict.ContainsKey(checkpoint_key))
                    {
                        internal_dict[checkpoint_key] = new Dictionary<string, Tensor>();
                    }
                    internal_dict[checkpoint_key][slice_spec] = tensor;
                }
            }
        }

        _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value));

        _registered_savers = new Dictionary<string, (RestoreFunc, RestoreFunc)>();
        if(registered_savers is not null && registered_savers.Count > 0)
        {
            // TODO: complete the implementation.
            throw new NotImplementedException();
        }
    }

    /// <summary>
    /// Saves every per-device shard to a temporary location and merges the shards into the
    /// final checkpoint at <paramref name="file_prefix"/>. Returns the merge op.
    /// </summary>
    public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
    {
        if(options is null)
        {
            options = new CheckpointOptions();
        }

        tf.device("CPU"); // may be risky.
        // s3 paths cannot use the "_temp/part" rename trick, so use a ".part" suffix there.
        var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
            constant_op.constant(".part"), constant_op.constant("_temp/part"));
        var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
        IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));

        Operation save_fn()
        {
            List<Tensor> saved_prefixes= new();
            foreach(var saver in _registered_savers)
            {
                // TODO: implement it later.
                throw new NotImplementedException();
            }

            int num_shards = _single_device_savers.Count;
            List<Operation> sharded_saves = new();
            var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards");
            string? last_device = null;
            int shard = 0;
            // Deterministic shard numbering: iterate devices in sorted order.
            // NOTE(review): `shard` is never incremented in this loop, so every saver
            // writes shard 0 — verify against the python original.
            foreach(var pair in _single_device_savers.OrderBy(x => x.Key))
            {
                var device = pair.Key;
                var saver = pair.Value;
                last_device = device;
                // skip the extra process of device name because of lack of API.
                tf.device(device);
                var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
                saved_prefixes.Add(shard_prefix);
                sharded_saves.Add(saver.save(shard_prefix, options));
            }
            // Merge only after all shards have been written.
            using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
            {
                string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
                tf.device(merge_device);
                // NOTE(review): `file_prefix` is already a Tensor; wrapping it in
                // `tf.constant(...)` here looks suspicious — confirm against the python code.
                return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
            }
        }

        if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1)
        {
            // TODO: implement it. Currently `autograph` does not support the function with non parameter.
            throw new NotImplementedException();
        }
        else
        {
            return save_fn();
        }
    }

    public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options);

    /// <summary>
    /// Restores all shards from <paramref name="file_prefix"/> and dispatches the restored
    /// tensors to the owning trackables' restore callbacks once each callback has received
    /// every key it is responsible for. Returns the restore ops keyed by name.
    /// </summary>
    public IDictionary<string, Operation> restore(Tensor file_prefix, CheckpointOptions? options = null)
    {
        if(options is null)
        {
            options = new CheckpointOptions();
        }

        IDictionary<string, Operation> restore_func()
        {
            // Accumulates inputs per restore callback until it has all of its keys.
            Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new();
            // Count down how many keys each callback is still waiting for.
            Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count);
            Dictionary<string, Operation> restore_ops = new();

            foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key))
            {
                var device = single_saver.Key;
                var saver = single_saver.Value;
                tf.device(device);
                var restored_tensor_dict = saver.restore(file_prefix, options);

                foreach(var pair in restored_tensor_dict)
                {
                    var checkpoint_key = pair.Key;
                    var slice_and_tensor = pair.Value;
                    foreach(var item in slice_and_tensor)
                    {
                        var slice_spec = item.Key;
                        var tensor = item.Value;
                        var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
                        var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
                        if (!string.IsNullOrEmpty(slice_spec))
                        {
                            // Sliced value: collect slices under their checkpoint key.
                            if (!internal_dict.ContainsKey(checkpoint_key))
                            {
                                Dictionary<string, Tensor> dict = new();
                                dict[slice_spec] = tensor;
                                internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
                            }
                            else
                            {
                                internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
                            }
                        }
                        else
                        {
                            // Unsliced value: store the tensor directly.
                            internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
                        }
                        restore_fn_input_count[restore_fn]--;

                        // All keys for this callback have arrived: invoke it now.
                        if (restore_fn_input_count[restore_fn] == 0)
                        {
                            Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
                            foreach(var input in restore_fn_inputs[restore_fn])
                            {
                                restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
                            }
                            var ret = restore_fn.DynamicInvoke(restored_tensors);
                            if(ret is IDictionary<string, Operation>)
                            {
                                var dict = (IDictionary<string, Operation>)ret;
                                restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
                            }
                        }
                    }
                }
            }

            foreach(var item in _registered_savers)
            {
                throw new NotImplementedException();
            }
            return restore_ops;
        }

        // TODO: complete the implementation. Currently skip it because of lack of API.
        bool has_custom_device_saver = false;

        if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver))
        {
            // TODO: implement it. Currently `autograph` does not support the function with non parameter.
            throw new NotImplementedException();
        }
        else
        {
            return restore_func();
        }
    }

    /// <summary>
    /// Serializes to a SaverDef referencing the current graph.
    /// </summary>
    public SaverDef to_proto()
    {
        // Trace save/restore into graph functions driven by a filename placeholder.
        var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename");
        var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING);
        var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING);
        var save_tensor = traced_save_func(filename_tensor);
        var restore_op = traced_restore_func(filename_tensor).op;
        return new SaverDef()
        {
            FilenameTensorName = filename_tensor.name,
            SaveTensorName = save_tensor.name,
            RestoreOpName = restore_op.name,
            Version = SaverDef.Types.CheckpointFormatVersion.V2
        };
    }

    // Traced body for SaverDef: save, then return the prefix gated on the save op.
    private Tensor _traced_save(Tensor file_prefix)
    {
        var save_op = save(file_prefix);
        tf.device("cpu:0");
        using (ops.control_dependencies(new object[]{ save_op }))
        {
            return array_ops.identity(file_prefix);
        }
    }

    // Traced body for SaverDef: restore, then return the prefix gated on all restore ops.
    private Tensor _traced_restore(Tensor file_prefix)
    {
        var restore_op = restore(file_prefix);
        tf.device("cpu:0");
        using (ops.control_dependencies(restore_op.Values.ToArray()))
        {
            return array_ops.identity(file_prefix);
        }
    }

    /// <summary>
    /// Builds a MultiDeviceSaver from prebuilt saveables by wrapping each in a
    /// compatibility converter trackable.
    /// </summary>
    public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
    {
        Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
        foreach (var saveable in saveables)
        {
            var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable });
            serialized_tensors[trackable] = trackable.serialize_to_tensors();
        }
        return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures);
    }

    // File name used for a registered saver's private shard.
    private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name)
    {
        return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") });
    }
    // Standard "<prefix>-00000-of-00042"-style shard filename.
    private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards)
    {
        return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards);
    }
}
}

+ 69
- 1
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Tensorflow.Train;

namespace Tensorflow
{
@@ -90,4 +91,71 @@ namespace Tensorflow
        Dispose(false);
    }
}
}

/// <summary>
/// Base class for trackable objects that also own an unmanaged handle,
/// combining <see cref="Trackable"/> with the standard dispose/finalizer pattern.
/// </summary>
public abstract class DisposableTrackableObject: Trackable, IDisposable
{
    // Unmanaged handle owned by this object; reset to IntPtr.Zero once released.
    protected IntPtr _handle;
    // Set after the first disposal so cleanup runs only once.
    protected bool _disposed;

    protected DisposableTrackableObject()
    { }

    protected DisposableTrackableObject(IntPtr handle)
        => _handle = handle;

    // Core of the dispose pattern; `disposing` is false when invoked from the finalizer.
    private void Dispose(bool disposing)
    {
        if (_disposed)
            return;

        //first handle managed, they might use the unmanaged resources.
        if (disposing)
        {
            // dispose managed state (managed objects).
            DisposeManagedResources();
        }

        // free unmanaged memory
        if (_handle != IntPtr.Zero)
        {
            // Call the appropriate methods to clean up
            // unmanaged resources here.
            // If disposing is false,
            // only the following code is executed.
            DisposeUnmanagedResources(_handle);
            _handle = IntPtr.Zero;
        }

        // Note disposing has been done.
        _disposed = true;
    }

    /// <summary>
    /// Dispose any managed resources.
    /// </summary>
    /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks>
    protected virtual void DisposeManagedResources()
    { }

    /// <summary>
    /// Dispose any unmanaged resources related to given <paramref name="handle"/>.
    /// </summary>
    protected abstract void DisposeUnmanagedResources(IntPtr handle);

    public void Dispose()
    {
        Dispose(true);
        // This object will be cleaned up by the Dispose method.
        // Therefore, you should call GC.SupressFinalize to
        // take this object off the finalization queue
        // and prevent finalization code for this object
        // from executing a second time.
        GC.SuppressFinalize(this);
    }

    // Finalizer: releases unmanaged resources if Dispose was never called.
    ~DisposableTrackableObject()
    {
        Dispose(false);
    }
}
}

+ 31
- 0
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Contexts;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
    /// <summary>
    /// Low-level helpers for executing TensorFlow ops eagerly.
    /// </summary>
    internal class execute
    {
        /// <summary>
        /// Converts <paramref name="values"/> to eager tensors on <paramref name="ctx"/>
        /// and returns their protobuf datatypes alongside the converted tensors.
        /// </summary>
        /// <param name="values">Tensors (or tensor-likes) to convert.</param>
        /// <param name="ctx">Eager context used for the conversion.</param>
        public static (DataType[], Tensor[]) convert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
        {
            // Materialize once: the Select pipelines are lazy, so calling ToArray on
            // both `types` and `v` from the same lazy source would run
            // ops.convert_to_tensor twice per input tensor.
            var tensors = values.Select(t => ops.convert_to_tensor(t, ctx: ctx)).ToArray();
            var types = tensors.Select(t => t.dtype.as_datatype_enum()).ToArray();
            return (types, tensors);
        }

        /// <summary>
        /// Backward-compatible alias kept because the original method name contained a
        /// typo ("onvert" instead of "convert"); delegates to the corrected method.
        /// </summary>
        [Obsolete("Use convert_to_mixed_eager_tensors instead; this name contains a typo.")]
        public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
            => convert_to_mixed_eager_tensors(values, ctx);

        /// <summary>
        /// Executes the op named <paramref name="op_name"/> eagerly on the context's
        /// current device and returns its output tensors.
        /// </summary>
        /// <param name="op_name">Registered op name to execute.</param>
        /// <param name="num_outputs">Number of outputs the op produces.</param>
        /// <param name="inputs">Input tensors.</param>
        /// <param name="attrs">Flattened attr name/value pairs, or null.</param>
        /// <param name="ctx">Eager context; initialized lazily here if needed.</param>
        /// <param name="name">Currently unused — presumably reserved for error messages; TODO confirm.</param>
        public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
        {
            string device_name = ctx.DeviceName;

            ctx.ensure_initialized();
            return tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs);
        }
    }
}

+ 14
- 0
src/TensorFlowNET.Core/Exceptions/AssertionError.cs View File

@@ -0,0 +1,14 @@
namespace Tensorflow.Exceptions;

/// <summary>
/// Exception signaling a failed runtime assertion inside TensorFlow.NET.
/// </summary>
public class AssertionError : TensorflowException
{
    /// <summary>Creates an <see cref="AssertionError"/> without a message.</summary>
    public AssertionError()
    {
    }

    /// <summary>Creates an <see cref="AssertionError"/> carrying <paramref name="message"/>.</summary>
    public AssertionError(string message) : base(message)
    {
    }
}

+ 85
- 1
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -304,7 +304,7 @@ namespace Tensorflow
} }
} }


private static OpList stripped_op_list_for_graph(GraphDef graph_def)
public static OpList stripped_op_list_for_graph(GraphDef graph_def)
{ {
var used_ops = ops_used_by_graph_def(graph_def); var used_ops = ops_used_by_graph_def(graph_def);


@@ -345,5 +345,89 @@ namespace Tensorflow


return used_ops.ToArray(); return used_ops.ToArray();
} }

/// <summary>
/// Reports whether <paramref name="attr_name"/> is declared on <paramref name="op_def"/>
/// with a (non-null) default value. Attrs not declared on the op yield false.
/// NOTE(review): the actual equality check of <paramref name="attr_value"/> against the
/// default is still unimplemented (see TODO), so any declared attr with a default
/// is currently treated as matching it.
/// </summary>
private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value)
{
    foreach (var attr_def in op_def.Attr)
    {
        if (attr_def.Name != attr_name)
            continue;
        // TODO: add new c_api `EqualAttrValueWrapper` and complete the check.
        return attr_def.DefaultValue is not null;
    }
    // The op does not declare this attr at all.
    return false;
}

/// <summary>
/// Removes node attrs that equal their op's registered default from every node in the
/// graph, including nodes inside function definitions, then marks the meta-graph via
/// `MetaInfoDef.StrippedDefaultAttrs`. Nodes whose op refers to a function in the
/// graph's own library are left untouched.
/// </summary>
public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def)
{
    // Index the function library so function-call nodes can be skipped below.
    Dictionary<string, FunctionDef> functions_by_name = new();
    foreach (var fn in meta_graph_def.GraphDef.Library.Function)
        functions_by_name[fn.Signature.Name] = fn;

    // Local function (instead of an Action lambda): identical behavior.
    void strip_node(NodeDef node_def)
    {
        if (functions_by_name.ContainsKey(node_def.Op))
            return;

        var op_def = op_def_registry.GetOpDef(node_def.Op);
        if (op_def is null)
            return;

        HashSet<string> keys_to_remove = new();
        foreach (var attr in node_def.Attr)
        {
            if (is_default_attr_value(op_def, attr.Key, attr.Value))
                keys_to_remove.Add(attr.Key);
        }

        // Remove after the scan so the attr map is not mutated while enumerating.
        foreach (var key in keys_to_remove)
            node_def.Attr.Remove(key);
    }

    foreach (var node in meta_graph_def.GraphDef.Node)
        strip_node(node);

    foreach (var fn in meta_graph_def.GraphDef.Library.Function)
    {
        foreach (var node in fn.NodeDef)
            strip_node(node);
    }

    meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
}

/// <summary>
/// Extracts the producing op's name from a tensor name: strips a leading control-input
/// marker ("^") and anything from the first ":" onward (the output index).
/// </summary>
/// <param name="tensor_name">Tensor name such as "scope/op:0" or "^ctrl_dep".</param>
/// <returns>The bare op name.</returns>
/// <exception cref="ValueError">Thrown when <paramref name="tensor_name"/> is null or empty.</exception>
public static string op_name(string tensor_name)
{
    if (string.IsNullOrEmpty(tensor_name))
    {
        throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}.");
    }

    var name = tensor_name.StartsWith("^") ? tensor_name.Substring(1) : tensor_name;
    int colon = name.IndexOf(':');
    return colon >= 0 ? name.Substring(0, colon) : name;
}
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Framework.Models; using Tensorflow.Framework.Models;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Functions namespace Tensorflow.Functions
@@ -10,7 +11,7 @@ namespace Tensorflow.Functions
/// <summary> /// <summary>
/// ///
/// </summary> /// </summary>
public class ConcreteFunction
public class ConcreteFunction: Trackable
{ {
FuncGraph func_graph; FuncGraph func_graph;
ForwardBackwardCall forward_backward; ForwardBackwardCall forward_backward;


+ 9
- 2
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -1,16 +1,23 @@
using System; using System;
using Tensorflow.Train;


namespace Tensorflow namespace Tensorflow
{ {
public class Function
public class Function: Trackable
{ {
#pragma warning disable CS0169 // The field 'Function._handle' is never used #pragma warning disable CS0169 // The field 'Function._handle' is never used
private IntPtr _handle; private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used #pragma warning restore CS0169 // The field 'Function._handle' is never used

public string Name { get; set; }
public Function() public Function()
{ {


} }
public Function(string name)
{
Name = name;
}
} }
} }

+ 32
- 14
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Diagnostics;
using System.Linq; using System.Linq;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -6,14 +7,14 @@ namespace Tensorflow.Graphs
{ {
public class AutoGraph public class AutoGraph
{ {
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32)
{ {
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; string func_name = $"{func.Method.Name}_{ops.uid_function()}";


var graph = new FuncGraph(func_name); var graph = new FuncGraph(func_name);
graph.as_default(); graph.as_default();


var input = tf.placeholder(tf.int32);
var input = tf.placeholder(dtype);
var output = func(input); var output = func(input);


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
@@ -26,25 +27,33 @@ namespace Tensorflow.Graphs


return (Tensor input) => return (Tensor input) =>
{ {
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
}
using (var s = tf.Session(input.graph))
{
var output = func(input);
return output;
}
}; };
} }


public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes)
{ {
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; string func_name = $"{func.Method.Name}_{ops.uid_function()}";


var graph = new FuncGraph(func_name); var graph = new FuncGraph(func_name);
graph.as_default(); graph.as_default();


var input1 = tf.placeholder(tf.int32);
var input2 = tf.placeholder(tf.int32);
var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32);
var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32);
var output = func(input1, input2); var output = func(input1, input2);


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
@@ -56,13 +65,22 @@ namespace Tensorflow.Graphs
return (Tensor a, Tensor b) => return (Tensor a, Tensor b) =>
{ {
var result = tf.Runner.TFE_Execute(tf.Context,
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName, tf.Context.DeviceName,
func_name, func_name,
new[] { a, b }, new[] { a, b },
null, null,
1); 1);
return result[0];
return result[0];
}
using (var s = tf.Session(a.graph))
{
Debug.Assert(a.graph == b.graph);
var output = func(a, b);
return output;
}
}; };
} }
} }


+ 7
- 4
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs View File

@@ -1,9 +1,12 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.ArgsDefinition { namespace Tensorflow.Keras.ArgsDefinition {
public class ELUArgs : LayerArgs {
public float Alpha { get; set; } = 0.1f;
}
public class ELUArgs : AutoSerializeLayerArgs
{
[JsonProperty("alpha")]
public float Alpha { get; set; } = 0.1f;
}
} }

+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs View File

@@ -1,14 +1,16 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class LeakyReLuArgs : LayerArgs
public class LeakyReLuArgs : AutoSerializeLayerArgs
{ {
/// <summary> /// <summary>
/// Negative slope coefficient. /// Negative slope coefficient.
/// </summary> /// </summary>
[JsonProperty("alpha")]
public float Alpha { get; set; } = 0.3f; public float Alpha { get; set; } = 0.3f;
} }
} }

+ 7
- 4
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs View File

@@ -1,9 +1,12 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.ArgsDefinition { namespace Tensorflow.Keras.ArgsDefinition {
public class SoftmaxArgs : LayerArgs {
public Axis axis { get; set; } = -1;
}
public class SoftmaxArgs : AutoSerializeLayerArgs
{
[JsonProperty("axis")]
public Axis axis { get; set; } = -1;
}
} }

+ 4
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs View File

@@ -1,3 +1,5 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class AttentionArgs : BaseDenseAttentionArgs public class AttentionArgs : BaseDenseAttentionArgs
@@ -6,6 +8,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// <summary> /// <summary>
/// If `true`, will create a scalar variable to scale the attention scores. /// If `true`, will create a scalar variable to scale the attention scores.
/// </summary> /// </summary>
[JsonProperty("use_scale")]
public bool use_scale { get; set; } = false; public bool use_scale { get; set; } = false;


/// <summary> /// <summary>
@@ -14,6 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// and key vectors. `"concat"` refers to the hyperbolic tangent of the /// and key vectors. `"concat"` refers to the hyperbolic tangent of the
/// concatenation of the query and key vectors. /// concatenation of the query and key vectors.
/// </summary> /// </summary>
[JsonProperty("score_mode")]
public string score_mode { get; set; } = "dot"; public string score_mode { get; set; } = "dot";


} }

+ 4
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs View File

@@ -1,6 +1,8 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class BaseDenseAttentionArgs : LayerArgs
public class BaseDenseAttentionArgs : AutoSerializeLayerArgs
{ {


/// <summary> /// <summary>
@@ -14,6 +16,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// Float between 0 and 1. Fraction of the units to drop for the /// Float between 0 and 1. Fraction of the units to drop for the
/// attention scores. /// attention scores.
/// </summary> /// </summary>
[JsonProperty("dropout")]
public float dropout { get; set; } = 0f; public float dropout { get; set; } = 0f;


} }

+ 19
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs View File

@@ -1,22 +1,40 @@
using Newtonsoft.Json;
using System; using System;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class MultiHeadAttentionArgs : LayerArgs
public class MultiHeadAttentionArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("num_heads")]
public int NumHeads { get; set; } public int NumHeads { get; set; }
[JsonProperty("key_dim")]
public int KeyDim { get; set; } public int KeyDim { get; set; }
[JsonProperty("value_dim")]
public int? ValueDim { get; set; } = null; public int? ValueDim { get; set; } = null;
[JsonProperty("dropout")]
public float Dropout { get; set; } = 0f; public float Dropout { get; set; } = 0f;
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true; public bool UseBias { get; set; } = true;
[JsonProperty("output_shape")]
public Shape OutputShape { get; set; } = null; public Shape OutputShape { get; set; } = null;
[JsonProperty("attention_axes")]
public Shape AttentionAxis { get; set; } = null; public Shape AttentionAxis { get; set; } = null;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; } = null; public IRegularizer KernelRegularizer { get; set; } = null;
[JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; } = null; public IRegularizer BiasRegularizer { get; set; } = null;
[JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; } = null; public Action KernelConstraint { get; set; } = null;
[JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; } = null; public Action BiasConstraint { get; set; } = null;
[JsonProperty("activity_regularizer")]
public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }

// TODO: Add `key_shape`, `value_shape`, `query_shape`.
} }
} }

+ 25
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs View File

@@ -0,0 +1,25 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
    /// <summary>
    /// This class has nothing but the attributes different from `LayerArgs`.
    /// It's used to serialize the model to `tf` format.
    /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`,
    /// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`.
    /// </summary>
    public class AutoSerializeLayerArgs: LayerArgs
    {
        // Common layer config entries that are always written to the serialized JSON.
        [JsonProperty("name")]
        public override string Name { get => base.Name; set => base.Name = value; }
        [JsonProperty("dtype")]
        public override TF_DataType DType { get => base.DType; set => base.DType = value; }
        // Omitted from the JSON when null — presumably so only input-like layers
        // carry batch_input_shape; TODO confirm against the model loader.
        [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
        public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
        [JsonProperty("trainable")]
        public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
    }
}

+ 37
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs View File

@@ -1,31 +1,65 @@
using System;
using Newtonsoft.Json;
using System;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class ConvolutionalArgs : LayerArgs
public class ConvolutionalArgs : AutoSerializeLayerArgs
{ {
public int Rank { get; set; } = 2; public int Rank { get; set; } = 2;
[JsonProperty("filters")]
public int Filters { get; set; } public int Filters { get; set; }
public int NumSpatialDims { get; set; } = Unknown; public int NumSpatialDims { get; set; } = Unknown;
[JsonProperty("kernel_size")]
public Shape KernelSize { get; set; } = 5; public Shape KernelSize { get; set; } = 5;


/// <summary> /// <summary>
/// specifying the stride length of the convolution. /// specifying the stride length of the convolution.
/// </summary> /// </summary>
[JsonProperty("strides")]
public Shape Strides { get; set; } = (1, 1); public Shape Strides { get; set; } = (1, 1);
[JsonProperty("padding")]
public string Padding { get; set; } = "valid"; public string Padding { get; set; } = "valid";
[JsonProperty("data_format")]
public string DataFormat { get; set; } public string DataFormat { get; set; }
[JsonProperty("dilation_rate")]
public Shape DilationRate { get; set; } = (1, 1); public Shape DilationRate { get; set; } = (1, 1);
[JsonProperty("groups")]
public int Groups { get; set; } = 1; public int Groups { get; set; } = 1;
public Activation Activation { get; set; } public Activation Activation { get; set; }
private string _activationName;
[JsonProperty("activation")]
public string ActivationName
{
get
{
if (string.IsNullOrEmpty(_activationName))
{
return Activation.Method.Name;
}
else
{
return _activationName;
}
}
set
{
_activationName = value;
}
}
[JsonProperty("use_bias")]
public bool UseBias { get; set; } public bool UseBias { get; set; }
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; } public IRegularizer KernelRegularizer { get; set; }
[JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; } public IRegularizer BiasRegularizer { get; set; }
[JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; } public Action KernelConstraint { get; set; }
[JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; } public Action BiasConstraint { get; set; }
} }
} }

+ 41
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs View File

@@ -1,13 +1,18 @@
using System;
using Newtonsoft.Json;
using System;
using System.Xml.Linq;
using Tensorflow.Operations.Initializers;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
// TODO: `activity_regularizer`
public class DenseArgs : LayerArgs public class DenseArgs : LayerArgs
{ {
/// <summary> /// <summary>
/// Positive integer, dimensionality of the output space. /// Positive integer, dimensionality of the output space.
/// </summary> /// </summary>
[JsonProperty("units")]
public int Units { get; set; } public int Units { get; set; }


/// <summary> /// <summary>
@@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition
/// </summary> /// </summary>
public Activation Activation { get; set; } public Activation Activation { get; set; }


private string _activationName;
[JsonProperty("activation")]
public string ActivationName
{
get
{
if (string.IsNullOrEmpty(_activationName))
{
return Activation.Method.Name;
}
else
{
return _activationName;
}
}
set
{
_activationName = value;
}
}

/// <summary> /// <summary>
/// Whether the layer uses a bias vector. /// Whether the layer uses a bias vector.
/// </summary> /// </summary>
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true; public bool UseBias { get; set; } = true;


/// <summary> /// <summary>
/// Initializer for the `kernel` weights matrix. /// Initializer for the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;


/// <summary> /// <summary>
/// Initializer for the bias vector. /// Initializer for the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;


/// <summary> /// <summary>
/// Regularizer function applied to the `kernel` weights matrix. /// Regularizer function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; } public IRegularizer KernelRegularizer { get; set; }


/// <summary> /// <summary>
/// Regularizer function applied to the bias vector. /// Regularizer function applied to the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; } public IRegularizer BiasRegularizer { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the `kernel` weights matrix. /// Constraint function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; } public Action KernelConstraint { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the bias vector. /// Constraint function applied to the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; } public Action BiasConstraint { get; set; }

[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
} }
} }

src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs View File

@@ -1,9 +1,10 @@
using Newtonsoft.Json;
using System; using System;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition
namespace Tensorflow.Keras.ArgsDefinition.Core
{ {
public class EinsumDenseArgs : LayerArgs
public class EinsumDenseArgs : AutoSerializeLayerArgs
{ {
/// <summary> /// <summary>
/// An equation describing the einsum to perform. This equation must /// An equation describing the einsum to perform. This equation must
@@ -11,6 +12,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis /// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
/// expression sequence. /// expression sequence.
/// </summary> /// </summary>
[JsonProperty("equation")]
public string Equation { get; set; } public string Equation { get; set; }


/// <summary> /// <summary>
@@ -19,6 +21,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// None for any dimension that is unknown or can be inferred from the input /// None for any dimension that is unknown or can be inferred from the input
/// shape. /// shape.
/// </summary> /// </summary>
[JsonProperty("output_shape")]
public Shape OutputShape { get; set; } public Shape OutputShape { get; set; }


/// <summary> /// <summary>
@@ -26,41 +29,70 @@ namespace Tensorflow.Keras.ArgsDefinition
/// Each character in the `bias_axes` string should correspond to a character /// Each character in the `bias_axes` string should correspond to a character
/// in the output portion of the `equation` string. /// in the output portion of the `equation` string.
/// </summary> /// </summary>
[JsonProperty("bias_axes")]
public string BiasAxes { get; set; } = null; public string BiasAxes { get; set; } = null;


/// <summary> /// <summary>
/// Activation function to use. /// Activation function to use.
/// </summary> /// </summary>
public Activation Activation { get; set; } public Activation Activation { get; set; }
private string _activationName;
[JsonProperty("activation")]
public string ActivationName
{
get
{
if (string.IsNullOrEmpty(_activationName))
{
return Activation.Method.Name;
}
else
{
return _activationName;
}
}
set
{
_activationName = value;
}
}


/// <summary> /// <summary>
/// Initializer for the `kernel` weights matrix. /// Initializer for the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;


/// <summary> /// <summary>
/// Initializer for the bias vector. /// Initializer for the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;


/// <summary> /// <summary>
/// Regularizer function applied to the `kernel` weights matrix. /// Regularizer function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; } public IRegularizer KernelRegularizer { get; set; }


/// <summary> /// <summary>
/// Regularizer function applied to the bias vector. /// Regularizer function applied to the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; } public IRegularizer BiasRegularizer { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the `kernel` weights matrix. /// Constraint function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; } public Action KernelConstraint { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the bias vector. /// Constraint function applied to the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; } public Action BiasConstraint { get; set; }
[JsonProperty("activity_regularizer")]
public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }
} }
} }

+ 13
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs View File

@@ -1,11 +1,22 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class EmbeddingArgs : LayerArgs
public class EmbeddingArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("input_dim")]
public int InputDim { get; set; } public int InputDim { get; set; }
[JsonProperty("output_dim")]
public int OutputDim { get; set; } public int OutputDim { get; set; }
[JsonProperty("mask_zero")]
public bool MaskZero { get; set; } public bool MaskZero { get; set; }
[JsonProperty("input_length")]
public int InputLength { get; set; } = -1; public int InputLength { get; set; } = -1;
[JsonProperty("embeddings_initializer")]
public IInitializer EmbeddingsInitializer { get; set; } public IInitializer EmbeddingsInitializer { get; set; }
[JsonProperty("activity_regularizer")]
public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }

// TODO: `embeddings_regularizer`, `embeddings_constraint`.
} }
} }

+ 15
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs View File

@@ -1,9 +1,22 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Tensorflow.Keras.Common;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class InputLayerArgs : LayerArgs public class InputLayerArgs : LayerArgs
{ {
[JsonIgnore]
public Tensor InputTensor { get; set; } public Tensor InputTensor { get; set; }
public bool Sparse { get; set; }
[JsonProperty("sparse")]
public virtual bool Sparse { get; set; }
[JsonProperty("ragged")]
public bool Ragged { get; set; } public bool Ragged { get; set; }
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
} }
} }

+ 0
- 16
src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs View File

@@ -1,16 +0,0 @@
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition {
public class Cropping2DArgs : LayerArgs {
/// <summary>
/// channel last: (b, h, w, c)
/// channels_first: (b, c, h, w)
/// </summary>
public enum DataFormat { channels_first = 0, channels_last = 1 }
/// <summary>
/// Accept: int[1][2], int[1][1], int[2][2]
/// </summary>
public NDArray cropping { get; set; }
public DataFormat data_format { get; set; } = DataFormat.channels_last;
}
}

+ 0
- 16
src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs View File

@@ -1,16 +0,0 @@
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition {
public class Cropping3DArgs : LayerArgs {
/// <summary>
/// channel last: (b, h, w, c)
/// channels_first: (b, c, h, w)
/// </summary>
public enum DataFormat { channels_first = 0, channels_last = 1 }
/// <summary>
/// Accept: int[1][3], int[1][1], int[3][2]
/// </summary>
public NDArray cropping { get; set; }
public DataFormat data_format { get; set; } = DataFormat.channels_last;
}
}

+ 0
- 10
src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs View File

@@ -1,10 +0,0 @@
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition {
public class CroppingArgs : LayerArgs {
/// <summary>
/// Accept length 1 or 2
/// </summary>
public NDArray cropping { get; set; }
}
}

+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs View File

@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class DataAdapterArgs
public class DataAdapterArgs: IKerasConfig
{ {
public Tensor X { get; set; } public Tensor X { get; set; }
public Tensor Y { get; set; } public Tensor Y { get; set; }


+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs View File

@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class DataHandlerArgs
public class DataHandlerArgs: IKerasConfig
{ {
public Tensor X { get; set; } public Tensor X { get; set; }
public Tensor Y { get; set; } public Tensor Y { get; set; }


+ 17
- 14
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -1,51 +1,54 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class LayerArgs
[JsonObject(MemberSerialization.OptIn)]
public class LayerArgs: IKerasConfig
{ {
/// <summary> /// <summary>
/// Indicates whether the layer's weights are updated during training /// Indicates whether the layer's weights are updated during training
/// and whether the layer's updates are run during training. /// and whether the layer's updates are run during training.
/// </summary> /// </summary>
public bool Trainable { get; set; } = true;

public string Name { get; set; }
public virtual bool Trainable { get; set; } = true;
public virtual string Name { get; set; }


/// <summary> /// <summary>
/// Only applicable to input layers. /// Only applicable to input layers.
/// </summary> /// </summary>
public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;
public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;


/// <summary> /// <summary>
/// Whether the `call` method can be used to build a TF graph without issues. /// Whether the `call` method can be used to build a TF graph without issues.
/// This attribute has no effect if the model is created using the Functional /// This attribute has no effect if the model is created using the Functional
/// API. Instead, `model.dynamic` is determined based on the internal layers. /// API. Instead, `model.dynamic` is determined based on the internal layers.
/// </summary> /// </summary>
public bool Dynamic { get; set; } = false;
public virtual bool Dynamic { get; set; } = false;


/// <summary> /// <summary>
/// Only applicable to input layers. /// Only applicable to input layers.
/// </summary> /// </summary>
public Shape InputShape { get; set; }
public virtual Shape InputShape { get; set; }


/// <summary> /// <summary>
/// Only applicable to input layers. /// Only applicable to input layers.
/// </summary> /// </summary>
public Shape BatchInputShape { get; set; }
public virtual Shape BatchInputShape { get; set; }


public int BatchSize { get; set; } = -1;
public virtual int BatchSize { get; set; } = -1;


/// <summary> /// <summary>
/// Initial weight values. /// Initial weight values.
/// </summary> /// </summary>
public float[] Weights { get; set; }
public virtual float[] Weights { get; set; }


/// <summary> /// <summary>
/// Regularizer function applied to the output of the layer(its "activation"). /// Regularizer function applied to the output of the layer(its "activation").
/// </summary> /// </summary>
public IRegularizer ActivityRegularizer { get; set; }
public virtual IRegularizer ActivityRegularizer { get; set; }


public bool Autocast { get; set; }
public virtual bool Autocast { get; set; }


public bool IsFromConfig { get; set; }
public virtual bool IsFromConfig { get; set; }
} }
} }

+ 0
- 6
src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs View File

@@ -1,6 +0,0 @@
namespace Tensorflow.Keras.ArgsDefinition.Lstm
{
public class LSTMCellArgs : LayerArgs
{
}
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs View File

@@ -4,6 +4,7 @@ using System.Text;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
// TODO: complete the implementation
public class MergeArgs : LayerArgs public class MergeArgs : LayerArgs
{ {
public Tensors Inputs { get; set; } public Tensors Inputs { get; set; }


+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class NodeArgs
public class NodeArgs: IKerasConfig
{ {
public ILayer[] InboundLayers { get; set; } public ILayer[] InboundLayers { get; set; }
public int[] NodeIndices { get; set; } public int[] NodeIndices { get; set; }


+ 18
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs View File

@@ -1,21 +1,37 @@
using static Tensorflow.Binding;
using Newtonsoft.Json;
using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class BatchNormalizationArgs : LayerArgs
public class BatchNormalizationArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("axis")]
public Shape Axis { get; set; } = -1; public Shape Axis { get; set; } = -1;
[JsonProperty("momentum")]
public float Momentum { get; set; } = 0.99f; public float Momentum { get; set; } = 0.99f;
[JsonProperty("epsilon")]
public float Epsilon { get; set; } = 1e-3f; public float Epsilon { get; set; } = 1e-3f;
[JsonProperty("center")]
public bool Center { get; set; } = true; public bool Center { get; set; } = true;
[JsonProperty("scale")]
public bool Scale { get; set; } = true; public bool Scale { get; set; } = true;
[JsonProperty("beta_initializer")]
public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;
[JsonProperty("gamma_initializer")]
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;
[JsonProperty("moving_mean_initializer")]
public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer;
[JsonProperty("moving_variance_initializer")]
public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer;
[JsonProperty("beta_regularizer")]
public IRegularizer BetaRegularizer { get; set; } public IRegularizer BetaRegularizer { get; set; }
[JsonProperty("gamma_regularizer")]
public IRegularizer GammaRegularizer { get; set; } public IRegularizer GammaRegularizer { get; set; }
// TODO: `beta_constraint` and `gamma_constraint`.
[JsonProperty("renorm")]
public bool Renorm { get; set; } public bool Renorm { get; set; }
// TODO: `renorm_clipping` and `virtual_batch_size`.
[JsonProperty("renorm_momentum")]
public float RenormMomentum { get; set; } = 0.99f; public float RenormMomentum { get; set; } = 0.99f;
} }
} }

+ 13
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs View File

@@ -1,16 +1,27 @@
using static Tensorflow.Binding;
using Newtonsoft.Json;
using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class LayerNormalizationArgs : LayerArgs
public class LayerNormalizationArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("axis")]
public Axis Axis { get; set; } = -1; public Axis Axis { get; set; } = -1;
[JsonProperty("epsilon")]
public float Epsilon { get; set; } = 1e-3f; public float Epsilon { get; set; } = 1e-3f;
[JsonProperty("center")]
public bool Center { get; set; } = true; public bool Center { get; set; } = true;
[JsonProperty("scale")]
public bool Scale { get; set; } = true; public bool Scale { get; set; } = true;
[JsonProperty("beta_initializer")]
public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;
[JsonProperty("gamma_initializer")]
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;
[JsonProperty("beta_regularizer")]
public IRegularizer BetaRegularizer { get; set; } public IRegularizer BetaRegularizer { get; set; }
[JsonProperty("gamma_regularizer")]
public IRegularizer GammaRegularizer { get; set; } public IRegularizer GammaRegularizer { get; set; }

// TODO: `beta_constraint` and `gamma_constraint`.
} }
} }

+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class OptimizerV2Args
public class OptimizerV2Args: IKerasConfig
{ {
public string Name { get; set; } public string Name { get; set; }
public float LearningRate { get; set; } = 0.001f; public float LearningRate { get; set; } = 0.001f;


+ 8
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class Pooling1DArgs : LayerArgs
public class Pooling1DArgs : AutoSerializeLayerArgs
{ {
/// <summary> /// <summary>
/// The pooling function to apply, e.g. `tf.nn.max_pool2d`. /// The pooling function to apply, e.g. `tf.nn.max_pool2d`.
@@ -10,11 +12,13 @@
/// <summary> /// <summary>
/// specifying the size of the pooling window. /// specifying the size of the pooling window.
/// </summary> /// </summary>
[JsonProperty("pool_size")]
public int PoolSize { get; set; } public int PoolSize { get; set; }


/// <summary> /// <summary>
/// specifying the strides of the pooling operation. /// specifying the strides of the pooling operation.
/// </summary> /// </summary>
[JsonProperty("strides")]
public int Strides { public int Strides {
get { return _strides.HasValue ? _strides.Value : PoolSize; } get { return _strides.HasValue ? _strides.Value : PoolSize; }
set { _strides = value; } set { _strides = value; }
@@ -24,11 +28,13 @@
/// <summary> /// <summary>
/// The padding method, either 'valid' or 'same'. /// The padding method, either 'valid' or 'same'.
/// </summary> /// </summary>
[JsonProperty("padding")]
public string Padding { get; set; } = "valid"; public string Padding { get; set; } = "valid";


/// <summary> /// <summary>
/// one of `channels_last` (default) or `channels_first`. /// one of `channels_last` (default) or `channels_first`.
/// </summary> /// </summary>
[JsonProperty("data_format")]
public string DataFormat { get; set; } public string DataFormat { get; set; }
} }
} }

+ 8
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class Pooling2DArgs : LayerArgs
public class Pooling2DArgs : AutoSerializeLayerArgs
{ {
/// <summary> /// <summary>
/// The pooling function to apply, e.g. `tf.nn.max_pool2d`. /// The pooling function to apply, e.g. `tf.nn.max_pool2d`.
@@ -10,21 +12,25 @@
/// <summary> /// <summary>
/// specifying the size of the pooling window. /// specifying the size of the pooling window.
/// </summary> /// </summary>
[JsonProperty("pool_size")]
public Shape PoolSize { get; set; } public Shape PoolSize { get; set; }


/// <summary> /// <summary>
/// specifying the strides of the pooling operation. /// specifying the strides of the pooling operation.
/// </summary> /// </summary>
[JsonProperty("strides")]
public Shape Strides { get; set; } public Shape Strides { get; set; }


/// <summary> /// <summary>
/// The padding method, either 'valid' or 'same'. /// The padding method, either 'valid' or 'same'.
/// </summary> /// </summary>
[JsonProperty("padding")]
public string Padding { get; set; } = "valid"; public string Padding { get; set; } = "valid";


/// <summary> /// <summary>
/// one of `channels_last` (default) or `channels_first`. /// one of `channels_last` (default) or `channels_first`.
/// </summary> /// </summary>
[JsonProperty("data_format")]
public string DataFormat { get; set; } public string DataFormat { get; set; }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs View File

@@ -4,7 +4,7 @@ using System.Text;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class PreprocessingLayerArgs : LayerArgs
public class PreprocessingLayerArgs : AutoSerializeLayerArgs
{ {
} }
} }

+ 12
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs View File

@@ -0,0 +1,12 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{
    /// <summary>
    /// Arguments of the `Rescaling` preprocessing layer.
    /// Serialized into the keras JSON config via the `AutoSerializeLayerArgs` base.
    /// </summary>
    public class RescalingArgs : AutoSerializeLayerArgs
    {
        // Multiplicative factor applied to the inputs
        // (presumably output = input * Scale + Offset — confirm against the Rescaling layer).
        [JsonProperty("scale")]
        public float Scale { get; set; }
        // Additive constant applied after scaling.
        [JsonProperty("offset")]
        public float Offset { get; set; }
    }
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs View File

@@ -1,5 +1,6 @@
namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
// TODO: no corresponding class found in keras python, maybe obsolete?
public class ResizingArgs : PreprocessingLayerArgs public class ResizingArgs : PreprocessingLayerArgs
{ {
public int Height { get; set; } public int Height { get; set; }


+ 10
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs View File

@@ -1,4 +1,5 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


@@ -6,11 +7,19 @@ namespace Tensorflow.Keras.ArgsDefinition
{ {
public class TextVectorizationArgs : PreprocessingLayerArgs public class TextVectorizationArgs : PreprocessingLayerArgs
{ {
[JsonProperty("standardize")]
public Func<Tensor, Tensor> Standardize { get; set; } public Func<Tensor, Tensor> Standardize { get; set; }
[JsonProperty("split")]
public string Split { get; set; } = "standardize"; public string Split { get; set; } = "standardize";
[JsonProperty("max_tokens")]
public int MaxTokens { get; set; } = -1; public int MaxTokens { get; set; } = -1;
[JsonProperty("output_mode")]
public string OutputMode { get; set; } = "int"; public string OutputMode { get; set; } = "int";
[JsonProperty("output_sequence_length")]
public int OutputSequenceLength { get; set; } = -1; public int OutputSequenceLength { get; set; } = -1;
[JsonProperty("vocabulary")]
public string[] Vocabulary { get; set; } public string[] Vocabulary { get; set; }

// TODO: Add `ngrams`, `sparse`, `ragged`, `idf_weights`, `encoding`
} }
} }

+ 7
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs View File

@@ -1,21 +1,26 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class DropoutArgs : LayerArgs
public class DropoutArgs : AutoSerializeLayerArgs
{ {
/// <summary> /// <summary>
/// Float between 0 and 1. Fraction of the input units to drop. /// Float between 0 and 1. Fraction of the input units to drop.
/// </summary> /// </summary>
[JsonProperty("rate")]
public float Rate { get; set; } public float Rate { get; set; }


/// <summary> /// <summary>
/// 1D integer tensor representing the shape of the /// 1D integer tensor representing the shape of the
/// binary dropout mask that will be multiplied with the input. /// binary dropout mask that will be multiplied with the input.
/// </summary> /// </summary>
[JsonProperty("noise_shape")]
public Shape NoiseShape { get; set; } public Shape NoiseShape { get; set; }


/// <summary> /// <summary>
/// random seed. /// random seed.
/// </summary> /// </summary>
[JsonProperty("seed")]
public int? Seed { get; set; } public int? Seed { get; set; }


public bool SupportsMasking { get; set; } public bool SupportsMasking { get; set; }


+ 0
- 8
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs View File

@@ -1,8 +0,0 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class RescalingArgs : LayerArgs
{
public float Scale { get; set; }
public float Offset { get; set; }
}
}

+ 18
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs View File

@@ -0,0 +1,18 @@
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition.Reshaping
{
    /// <summary>
    /// Arguments of the `Cropping2D` layer (crops along the two spatial dimensions).
    /// </summary>
    public class Cropping2DArgs : LayerArgs
    {
        /// <summary>
        /// channel last: (b, h, w, c)
        /// channels_first: (b, c, h, w)
        /// </summary>
        public enum DataFormat { channels_first = 0, channels_last = 1 }
        /// <summary>
        /// Accept: int[1][2], int[1][1], int[2][2]
        /// (a single pair applied to both dims, a single symmetric value, or one pair per spatial dim).
        /// </summary>
        public NDArray cropping { get; set; }
        // Which axis holds the channels; defaults to channels_last.
        public DataFormat data_format { get; set; } = DataFormat.channels_last;
    }
}

+ 18
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs View File

@@ -0,0 +1,18 @@
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition.Reshaping
{
    /// <summary>
    /// Arguments of the `Cropping3D` layer (crops along the three spatial dimensions).
    /// </summary>
    public class Cropping3DArgs : LayerArgs
    {
        /// <summary>
        /// NOTE(review): for 3-D cropping the layouts presumably are
        /// channels_last: (b, d1, d2, d3, c) and channels_first: (b, c, d1, d2, d3);
        /// the original comment "(b, h, w, c)" looked copied from Cropping2D — confirm.
        /// </summary>
        public enum DataFormat { channels_first = 0, channels_last = 1 }
        /// <summary>
        /// Accept: int[1][3], int[1][1], int[3][2]
        /// (a single triple, a single symmetric value, or one (before, after) pair per spatial dim).
        /// </summary>
        public NDArray cropping { get; set; }
        // Which axis holds the channels; defaults to channels_last.
        public DataFormat data_format { get; set; } = DataFormat.channels_last;
    }
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs View File

@@ -0,0 +1,12 @@
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition.Reshaping
{
    /// <summary>
    /// Arguments of the `Cropping1D` layer (crops along the single spatial dimension).
    /// </summary>
    public class Cropping1DArgs : LayerArgs
    {
        /// <summary>
        /// Accept length 1 or 2
        /// (a symmetric crop value, or an explicit (begin, end) pair).
        /// </summary>
        public NDArray cropping { get; set; }
    }
}

+ 5
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs View File

@@ -1,7 +1,10 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class FlattenArgs : LayerArgs
public class FlattenArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("data_format")]
public string DataFormat { get; set; } public string DataFormat { get; set; }
} }
} }

+ 8
- 4
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs View File

@@ -1,5 +1,9 @@
namespace Tensorflow.Keras.ArgsDefinition {
public class PermuteArgs : LayerArgs {
public int[] dims { get; set; }
}
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition {
public class PermuteArgs : AutoSerializeLayerArgs
{
[JsonProperty("dims")]
public int[] dims { get; set; }
}
} }

+ 5
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs View File

@@ -1,7 +1,10 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class ReshapeArgs : LayerArgs
public class ReshapeArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("target_shape")]
public Shape TargetShape { get; set; } public Shape TargetShape { get; set; }
public object[] TargetShapeObjects { get; set; } public object[] TargetShapeObjects { get; set; }
} }


+ 7
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs View File

@@ -1,12 +1,17 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class UpSampling2DArgs : LayerArgs
public class UpSampling2DArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("size")]
public Shape Size { get; set; } public Shape Size { get; set; }
[JsonProperty("data_format")]
public string DataFormat { get; set; } public string DataFormat { get; set; }
/// <summary> /// <summary>
/// 'nearest', 'bilinear' /// 'nearest', 'bilinear'
/// </summary> /// </summary>
[JsonProperty("interpolation")]
public string Interpolation { get; set; } = "nearest"; public string Interpolation { get; set; } = "nearest";
} }
} }

+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs View File

@@ -2,6 +2,7 @@


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
// TODO: complete the implementation
public class ZeroPadding2DArgs : LayerArgs public class ZeroPadding2DArgs : LayerArgs
{ {
public NDArray Padding { get; set; } public NDArray Padding { get; set; }


src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs View File

@@ -1,9 +1,8 @@
using Tensorflow.Keras.ArgsDefinition.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Lstm
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{ {
public class LSTMArgs : RNNArgs public class LSTMArgs : RNNArgs
{ {
// TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; } public bool UnitForgetBias { get; set; }
public float Dropout { get; set; } public float Dropout { get; set; }
public float RecurrentDropout { get; set; } public float RecurrentDropout { get; set; }

+ 7
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -0,0 +1,7 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    // TODO: complete the implementation
    /// <summary>
    /// Arguments of the LSTM cell. Currently an empty placeholder that only
    /// carries the members inherited from LayerArgs.
    /// </summary>
    public class LSTMCellArgs : LayerArgs
    {
    }
}

+ 12
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -1,21 +1,30 @@
using System.Collections.Generic;
using Newtonsoft.Json;
using System.Collections.Generic;


namespace Tensorflow.Keras.ArgsDefinition.Rnn namespace Tensorflow.Keras.ArgsDefinition.Rnn
{ {
public class RNNArgs : LayerArgs
public class RNNArgs : AutoSerializeLayerArgs
{ {
public interface IRnnArgCell : ILayer public interface IRnnArgCell : ILayer
{ {
object state_size { get; } object state_size { get; }
} }

[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnArgCell Cell { get; set; } = null; public IRnnArgCell Cell { get; set; } = null;
[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false; public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
public bool ReturnState { get; set; } = false; public bool ReturnState { get; set; } = false;
[JsonProperty("go_backwards")]
public bool GoBackwards { get; set; } = false; public bool GoBackwards { get; set; } = false;
[JsonProperty("stateful")]
public bool Stateful { get; set; } = false; public bool Stateful { get; set; } = false;
[JsonProperty("unroll")]
public bool Unroll { get; set; } = false; public bool Unroll { get; set; } = false;
[JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false; public bool TimeMajor { get; set; } = false;
// TODO: Add `num_constants` and `zero_output_for_mask`.
public Dictionary<string, object> Kwargs { get; set; } = null; public Dictionary<string, object> Kwargs { get; set; } = null;


public int Units { get; set; } public int Units { get; set; }


+ 50
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs View File

@@ -0,0 +1,50 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// Serializes an `Activation` by writing its type name (an empty string when
    /// the value is null). Deserialization is not implemented yet and always throws.
    /// </summary>
    public class CustomizedActivationJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(Activation);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            if (value is null)
            {
                // A missing activation is stored as an empty string.
                var token = JToken.FromObject("");
                token.WriteTo(writer);
            }
            else if (value is not Activation)
            {
                throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}.");
            }
            else
            {
                // NOTE(review): if `Activation` is a delegate type, `GetType().Name` yields the
                // delegate type's name rather than the wrapped function's name — confirm this
                // round-trips correctly once ReadJson is implemented.
                var token = JToken.FromObject((value as Activation)!.GetType().Name);
                token.WriteTo(writer);
            }
        }

        /// <summary>
        /// Not implemented yet: mapping a serialized name back to an `Activation`
        /// requires a registry of known activation functions.
        /// </summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            throw new NotImplementedException();
        }
    }
}

+ 48
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs View File

@@ -0,0 +1,48 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// Serializes an `Axis` as a JSON array of its axis indices (an empty array
    /// for null) and deserializes such arrays back into an `Axis`.
    /// </summary>
    public class CustomizedAxisJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(Axis);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            if (value is null)
            {
                // A null axis is written as an empty array rather than JSON null.
                var token = JToken.FromObject(new int[] { });
                token.WriteTo(writer);
            }
            else if (value is not Axis)
            {
                throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}.");
            }
            else
            {
                var token = JToken.FromObject((value as Axis)!.axis);
                token.WriteTo(writer);
            }
        }

        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            // Deserialize directly to `int[]`. The original code deserialized to `long[]`
            // and then applied `(int[])(axis!)` — an invalid array cast (long[] is not
            // convertible to int[]) that throws InvalidCastException at runtime.
            var axis = serializer.Deserialize(reader, typeof(int[])) as int[];
            if (axis is null)
            {
                throw new ValueError("Cannot deserialize 'null' to `Axis`.");
            }
            return new Axis(axis!);
        }
    }
}

+ 73
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs View File

@@ -0,0 +1,73 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// Serializes a `NodeConfig` as the 3-element JSON array
    /// `[name, node_index, tensor_index]` used by keras configs, and reads it back.
    /// </summary>
    public class CustomizedNodeConfigJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(NodeConfig);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            if (value is null)
            {
                // `JToken.FromObject(null)` is unsafe (ArgumentNullException in recent
                // Json.NET versions) — emit a JSON null token directly instead.
                writer.WriteNull();
            }
            else if (value is not NodeConfig)
            {
                throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}.");
            }
            else
            {
                var config = value as NodeConfig;
                var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex });
                token.WriteTo(writer);
            }
        }

        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            var values = serializer.Deserialize(reader, typeof(object[])) as object[];
            if (values is null)
            {
                throw new ValueError("Cannot deserialize 'null' to `NodeConfig`.");
            }
            if (values.Length != 3)
            {
                throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
            }
            if (values[0] is not string)
            {
                throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
            }
            // Json.NET materializes JSON integers inside an `object[]` as `long` (Int64),
            // so the original `is not int` checks rejected every valid payload. Accept
            // both `int` and `long`, then narrow via Convert.ToInt32.
            if (values[1] is not (int or long))
            {
                throw new TypeError($"The second value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
            }
            if (values[2] is not (int or long))
            {
                throw new TypeError($"The third value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
            }
            return new NodeConfig()
            {
                Name = values[0] as string,
                NodeIndex = Convert.ToInt32(values[1]),
                TensorIndex = Convert.ToInt32(values[2])
            };
        }
    }
}

+ 67
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -0,0 +1,67 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// Serializes a `Shape` as a JSON array of nullable dims, writing an unknown
    /// dimension (-1) as JSON `null` (matching the keras python convention), and
    /// reads such arrays back into a `Shape`.
    /// </summary>
    public class CustomizedShapeJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(Shape);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            if (value is null)
            {
                // `JToken.FromObject(null)` is unsafe (ArgumentNullException in recent
                // Json.NET versions) — emit a JSON null token directly instead.
                writer.WriteNull();
            }
            else if (value is not Shape)
            {
                throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
            }
            else
            {
                var shape = (value as Shape)!;
                // Translate the internal "-1 means unknown" convention to JSON null.
                long?[] dims = new long?[shape.ndim];
                for (int i = 0; i < dims.Length; i++)
                {
                    dims[i] = shape.dims[i] == -1 ? null : (long?)shape.dims[i];
                }
                var token = JToken.FromObject(dims);
                token.WriteTo(writer);
            }
        }

        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
            if (dims is null)
            {
                throw new ValueError("Cannot deserialize 'null' to `Shape`.");
            }
            // JSON null dims map back to -1 (unknown dimension).
            long[] convertedDims = new long[dims.Length];
            for (int i = 0; i < dims.Length; i++)
            {
                convertedDims[i] = dims[i] ?? (-1);
            }
            return new Shape(convertedDims);
        }
    }
}

+ 30
- 1
src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs View File

@@ -16,23 +16,27 @@


using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
/// <summary> /// <summary>
/// Specifies the ndim, dtype and shape of every input to a layer. /// Specifies the ndim, dtype and shape of every input to a layer.
/// </summary> /// </summary>
public class InputSpec
public class InputSpec: IKerasConfigable
{ {
public int? ndim; public int? ndim;
public int? max_ndim;
public int? min_ndim; public int? min_ndim;
Dictionary<int, int> axes; Dictionary<int, int> axes;
Shape shape; Shape shape;
TF_DataType dtype;
public int[] AllAxisDim; public int[] AllAxisDim;


public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null, int? ndim = null,
int? min_ndim = null, int? min_ndim = null,
int? max_ndim = null,
Dictionary<int, int> axes = null, Dictionary<int, int> axes = null,
Shape shape = null) Shape shape = null)
{ {
@@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine
axes = new Dictionary<int, int>(); axes = new Dictionary<int, int>();
this.axes = axes; this.axes = axes;
this.min_ndim = min_ndim; this.min_ndim = min_ndim;
this.max_ndim = max_ndim;
this.shape = shape; this.shape = shape;
this.dtype = dtype;
if (ndim == null && shape != null) if (ndim == null && shape != null)
this.ndim = shape.ndim; this.ndim = shape.ndim;


@@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine
AllAxisDim = axes.Select(x => x.Value).ToArray(); AllAxisDim = axes.Select(x => x.Value).ToArray();
} }


public IKerasConfig get_config()
{
return new Config()
{
DType = dtype == TF_DataType.DtInvalid ? null : dtype,
Shape = shape,
Ndim = ndim,
MinNdim = min_ndim,
MaxNdim = max_ndim,
Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value)
};
}

public override string ToString() public override string ToString()
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}";

public class Config: IKerasConfig
{
public TF_DataType? DType { get; set; }
public Shape Shape { get; set; }
public int? Ndim { get; set; }
public int? MinNdim { get;set; }
public int? MaxNdim { get;set; }
public IDictionary<string, int> Axes { get; set; }
}
} }
} }

+ 4
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,10 +1,12 @@
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Training;


namespace Tensorflow.Keras namespace Tensorflow.Keras
{ {
public interface ILayer
public interface ILayer: IWithTrackable, IKerasConfigable
{ {
string Name { get; } string Name { get; }
bool Trainable { get; } bool Trainable { get; }
@@ -19,8 +21,8 @@ namespace Tensorflow.Keras
List<IVariableV1> NonTrainableWeights { get; } List<IVariableV1> NonTrainableWeights { get; }
Shape OutputShape { get; } Shape OutputShape { get; }
Shape BatchInputShape { get; } Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }
TF_DataType DType { get; } TF_DataType DType { get; }
int count_params(); int count_params();
LayerArgs get_config();
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs View File

@@ -1,5 +1,5 @@
using System; using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.NumPy; using Tensorflow.NumPy;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers


+ 15
- 0
src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving
{
    /// <summary>
    /// Marker interface for keras config objects (layer/model argument sets)
    /// that participate in keras JSON serialization.
    /// </summary>
    public interface IKerasConfig
    {
    }

    /// <summary>
    /// Implemented by objects that can produce an IKerasConfig describing
    /// how to reconstruct them.
    /// </summary>
    public interface IKerasConfigable
    {
        // Returns the serializable configuration of this object.
        IKerasConfig get_config();
    }
}

+ 7
- 2
src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs View File

@@ -1,4 +1,5 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
@@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {
public class LayerConfig
public class LayerConfig: IKerasConfig
{ {
[JsonProperty("name")]
public string Name { get; set; } public string Name { get; set; }
[JsonProperty("class_name")]
public string ClassName { get; set; } public string ClassName { get; set; }
[JsonProperty("config")]
public LayerArgs Config { get; set; } public LayerArgs Config { get; set; }
[JsonProperty("inbound_nodes")]
public List<NodeConfig> InboundNodes { get; set; } public List<NodeConfig> InboundNodes { get; set; }
} }
} }

+ 7
- 2
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -1,15 +1,20 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {
public class ModelConfig
public class ModelConfig : IKerasConfig
{ {
[JsonProperty("name")]
public string Name { get; set; } public string Name { get; set; }
[JsonProperty("layers")]
public List<LayerConfig> Layers { get; set; } public List<LayerConfig> Layers { get; set; }
[JsonProperty("input_layers")]
public List<NodeConfig> InputLayers { get; set; } public List<NodeConfig> InputLayers { get; set; }
[JsonProperty("output_layers")]
public List<NodeConfig> OutputLayers { get; set; } public List<NodeConfig> OutputLayers { get; set; }


public override string ToString() public override string ToString()


+ 5
- 2
src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs View File

@@ -1,10 +1,13 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.Common;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {
public class NodeConfig
[JsonConverter(typeof(CustomizedNodeConfigJsonConverter))]
public class NodeConfig : IKerasConfig
{ {
public string Name { get; set; } public string Name { get; set; }
public int NodeIndex { get; set; } public int NodeIndex { get; set; }


+ 35
- 0
src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel
{
    /// <summary>
    /// Holds the trackable functions and objects that are attached to a root
    /// object when it is serialized to a SavedModel.
    /// </summary>
    public interface ISerializedAttributes
    {
        /// <summary>
        /// All tracked functions held by this attribute set.
        /// </summary>
        IDictionary<string, Trackable> Functions { get; }

        /// <summary>
        /// All tracked checkpointable objects held by this attribute set.
        /// </summary>
        IDictionary<string, Trackable> CheckpointableObjects { get; }

        /// <summary>
        /// Returns functions to attach to the root object during serialization.
        /// </summary>
        IDictionary<string, Trackable> FunctionsToSerialize { get; }

        /// <summary>
        /// Returns objects to attach to the root object during serialization.
        /// </summary>
        IDictionary<string, Trackable> ObjectsToSerialize{get; }

        /// <summary>
        /// Saves function dictionary, and validates dictionary values.
        /// </summary>
        /// <param name="function_dict">Functions to store, keyed by name.</param>
        /// <returns>The resulting function dictionary.</returns>
        IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict);

        /// <summary>
        /// Saves objects to a dictionary, and validates the values.
        /// </summary>
        /// <param name="object_dict">Objects to store, keyed by name.</param>
        /// <returns>The resulting object dictionary.</returns>
        IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict);
    }
}

+ 21
- 0
src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs View File

@@ -0,0 +1,21 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
{
    /// <summary>
    /// JSON-serializable description of a tensor shape in the Keras
    /// {"class_name": "TensorShape", "items": [...]} form, where an unknown
    /// dimension is stored as null.
    /// </summary>
    public class TensorShapeConfig
    {
        [JsonProperty("class_name")]
        public string ClassName { get; set; } = "TensorShape";
        [JsonProperty("items")]
        public long?[] Items { get; set; }

        // Unknown dimensions (null) map to -1, the convention Shape uses.
        public static implicit operator Shape(TensorShapeConfig shape)
            => shape == null ? null : new Shape(shape.Items.Select(x => x ?? -1L).ToArray());

        // Fixed: converting a null Shape now yields null instead of throwing a
        // NullReferenceException, making both conversion directions symmetric.
        public static implicit operator TensorShapeConfig(Shape shape)
            => shape == null ? null : new TensorShapeConfig() { Items = shape.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() };
    }
}

+ 43
- 1
src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs View File

@@ -9,10 +9,52 @@ namespace Tensorflow.ModelSaving
/// </summary> /// </summary>
public class SaveOptions public class SaveOptions
{ {
bool save_debug_info;
public bool save_debug_info = false;
public IList<string>? namespace_white_list { get; set; } = null;
public IDictionary<string, object>? function_aliases { get; set; } = null;
public string? experimental_io_device { get; set; } = null;
// TODO: experimental
public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None;
public bool experimental_custom_gradients { get; set; } = true;
public SaveOptions(bool save_debug_info = false) public SaveOptions(bool save_debug_info = false)
{ {
this.save_debug_info = save_debug_info; this.save_debug_info = save_debug_info;
} }
} }

/// <summary>
/// Enum-like policy describing how variables are handled while saving a model.
/// Instances are compared by reference, so the well-known values below act as
/// singletons.
/// </summary>
public class VariablePolicy
{
    /// <summary>
    /// The string identifier of this policy (null for <see cref="None"/>).
    /// </summary>
    public string Policy { get; }

    private VariablePolicy(string policy)
    {
        Policy = policy;
    }

    // readonly: identity comparisons (e.g. save_variable_devices) rely on
    // these singletons, so they must never be reassigned.
    public static readonly VariablePolicy None = new(null);
    public static readonly VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices");
    public static readonly VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables");

    /// <summary>
    /// Whether variable device placement should be recorded under this policy
    /// (true for every policy except <see cref="None"/>).
    /// </summary>
    public bool save_variable_devices()
    {
        return this != None;
    }

    /// <summary>
    /// Tries to convert `obj` to a VariablePolicy instance.
    /// </summary>
    /// <param name="obj">null, an existing policy, or a value whose string form names a policy.</param>
    /// <returns>The matching policy.</returns>
    /// <exception cref="ValueError">If `obj` does not name a known policy.</exception>
    public static VariablePolicy from_obj(object obj)
    {
        if (obj is null) return None;
        if (obj is VariablePolicy policy) return policy;
        // Policy names are machine identifiers, not user-facing text, so the
        // comparison must not depend on the current culture (CA1304).
        var key = obj.ToString()?.ToLowerInvariant();
        return key switch
        {
            null => None,
            "save_variable_devices" => SAVE_VARIABLE_DEVICES,
            "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES,
            _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
        };
    }
}
} }

+ 10
- 1
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -14,20 +14,29 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Keras.Common;


namespace Tensorflow namespace Tensorflow
{ {
public record Axis(params int[] axis)
[JsonConverter(typeof(CustomizedAxisJsonConverter))]
public class Axis
{ {
public int[] axis { get; set; }
public int size => axis == null ? -1 : axis.Length; public int size => axis == null ? -1 : axis.Length;
public bool IsScalar { get; init; } public bool IsScalar { get; init; }


public int this[int index] => axis[index]; public int this[int index] => axis[index];


public Axis(params int[] axis)
{
this.axis = axis;
}

public static implicit operator int[]?(Axis axis) public static implicit operator int[]?(Axis axis)
=> axis?.axis; => axis?.axis;




+ 3
- 0
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -14,14 +14,17 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Keras.Common;
using Tensorflow.NumPy; using Tensorflow.NumPy;


namespace Tensorflow namespace Tensorflow
{ {
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape public class Shape
{ {
public int ndim => _dims == null ? -1 : _dims.Length; public int ndim => _dims == null ? -1 : _dims.Length;


+ 10
- 0
src/TensorFlowNET.Core/Operations/Initializers/Constant.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Constant<T> : IInitializer public class Constant<T> : IInitializer
@@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers
T value; T value;
bool _verify_shape; bool _verify_shape;


private readonly Dictionary<string, object> _config;

public string ClassName => "Constant";
public IDictionary<string, object> Config => _config;

public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false)
{ {
this.value = value; this.value = value;
this.dtype = dtype; this.dtype = dtype;
_verify_shape = verify_shape; _verify_shape = verify_shape;

_config = new Dictionary<string, object>();
_config["value"] = this.value;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 9
- 1
src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs View File

@@ -14,10 +14,17 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class GlorotUniform : VarianceScaling public class GlorotUniform : VarianceScaling
{ {
private readonly Dictionary<string, object> _config;

public override string ClassName => "GlorotUniform";
public override IDictionary<string, object> Config => _config;

public GlorotUniform(float scale = 1.0f, public GlorotUniform(float scale = 1.0f,
string mode = "FAN_AVG", string mode = "FAN_AVG",
bool uniform = true, bool uniform = true,
@@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers
seed: seed, seed: seed,
dtype: dtype) dtype: dtype)
{ {

_config = new Dictionary<string, object>();
_config["seed"] = _seed;
} }
} }
} }

+ 7
- 0
src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -14,10 +14,17 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System.Collections.Generic;

namespace Tensorflow namespace Tensorflow
{ {
public interface IInitializer public interface IInitializer
{ {
[JsonProperty("class_name")]
string ClassName { get; }
[JsonProperty("config")]
IDictionary<string, object> Config { get; }
Tensor Apply(InitializerArgs args); Tensor Apply(InitializerArgs args);
} }
} }

+ 7
- 0
src/TensorFlowNET.Core/Operations/Initializers/Ones.cs View File

@@ -14,12 +14,19 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Ones : IInitializer public class Ones : IInitializer
{ {
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "Ones";
public IDictionary<string, object> Config => new Dictionary<string, object>();

public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
{ {
this.dtype = dtype; this.dtype = dtype;


+ 6
- 1
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -1,4 +1,4 @@
/*****************************************************************************
/*****************************************************************************
Copyright 2023 Haiping Chen. All Rights Reserved. Copyright 2023 Haiping Chen. All Rights Reserved.


Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,6 +19,7 @@ using System.Linq;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Operations.Initializers; namespace Tensorflow.Operations.Initializers;
using System.Collections.Generic;


public class Orthogonal : IInitializer public class Orthogonal : IInitializer
{ {
@@ -31,6 +32,10 @@ public class Orthogonal : IInitializer
_seed = seed; _seed = seed;
} }


private readonly Dictionary<string, object> _config;

public string ClassName => "Orthogonal";
public IDictionary<string, object> Config => throw new NotImplementedException();
public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)
{ {
return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType);


+ 12
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class RandomNormal : IInitializer public class RandomNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed; private int? seed;
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "RandomNormal";
public IDictionary<string, object> Config => _config;

public RandomNormal(float mean = 0.0f, public RandomNormal(float mean = 0.0f,
float stddev = 0.05f, float stddev = 0.05f,
int? seed = null, int? seed = null,
@@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev; this.stddev = stddev;
this.seed = seed; this.seed = seed;
this.dtype = dtype; this.dtype = dtype;

_config = new Dictionary<string, object>();
_config["mean"] = this.mean;
_config["stddev"] = this.stddev;
_config["seed"] = this.seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 12
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class RandomUniform : IInitializer public class RandomUniform : IInitializer
@@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers
private float maxval; private float maxval;
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "RandomUniform";
public IDictionary<string, object> Config => _config;

public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null)
{ {
this.dtype = dtype; this.dtype = dtype;
this.minval = minval; this.minval = minval;
this.maxval = maxval; this.maxval = maxval;
this.seed = seed; this.seed = seed;

_config = new Dictionary<string, object>();
_config["minval"] = this.minval;
_config["maxval"] = this.maxval;
_config["seed"] = this.seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 11
- 0
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class TruncatedNormal : IInitializer public class TruncatedNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed; private int? seed;
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "TruncatedNormal";
public IDictionary<string, object> Config => _config;

public TruncatedNormal(float mean = 0.0f, public TruncatedNormal(float mean = 0.0f,
float stddev = 1.0f, float stddev = 1.0f,
int? seed = null, int? seed = null,
@@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev; this.stddev = stddev;
this.seed = seed; this.seed = seed;
this.dtype = dtype; this.dtype = dtype;
_config = new Dictionary<string, object>();
_config["mean"] = this.mean;
_config["stddev"] = this.stddev;
_config["seed"] = this.seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 13
- 0
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -15,7 +15,9 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Linq.Expressions;


namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
@@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers
protected int? _seed; protected int? _seed;
protected TF_DataType _dtype; protected TF_DataType _dtype;
protected bool _uniform; protected bool _uniform;
private readonly Dictionary<string, object> _config;

public virtual string ClassName => "VarianceScaling";

public virtual IDictionary<string, object> Config => _config;


public VarianceScaling(float factor = 2.0f, public VarianceScaling(float factor = 2.0f,
string mode = "FAN_IN", string mode = "FAN_IN",
@@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers
_seed = seed; _seed = seed;
_dtype = dtype; _dtype = dtype;
_uniform = uniform; _uniform = uniform;

_config = new();
_config["scale"] = _scale;
_config["mode"] = _mode;
_config["distribution"] = _distribution;
_config["seed"] = _seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 5
- 0
src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Zeros : IInitializer public class Zeros : IInitializer
@@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers
Shape shape; Shape shape;
TF_DataType dtype; TF_DataType dtype;


public string ClassName => "Zeros";
public IDictionary<string, object> Config => new Dictionary<string, object>();

public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT)
{ {
this.shape = shape; this.shape = shape;


+ 7
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -20,7 +20,9 @@ using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Operations; using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -75,6 +77,8 @@ namespace Tensorflow


public Shape BatchInputShape => throw new NotImplementedException(); public Shape BatchInputShape => throw new NotImplementedException();


public TensorShapeConfig BuildInputShape => throw new NotImplementedException();

public TF_DataType DType => throw new NotImplementedException(); public TF_DataType DType => throw new NotImplementedException();
protected bool built = false; protected bool built = false;
public bool Built => built; public bool Built => built;
@@ -143,7 +147,7 @@ namespace Tensorflow
throw new NotImplementedException(); throw new NotImplementedException();
} }


public LayerArgs get_config()
public IKerasConfig get_config()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
@@ -152,5 +156,7 @@ namespace Tensorflow
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }

public Trackable GetTrackable() { throw new NotImplementedException(); }
} }
} }

+ 55
- 4
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -1,6 +1,9 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Xml.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Operations namespace Tensorflow.Operations
@@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations
/// path in the input checkpoint_prefixes. This is useful when those paths are non /// path in the input checkpoint_prefixes. This is useful when those paths are non
/// user-facing temporary locations. /// user-facing temporary locations.
/// </remarks> /// </remarks>
public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints")
{
public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files));
result = null;
return null;
//try
//{
// var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
// new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }));
// result = null;
// return null;
//}
//catch (System.Exception)
//{
// return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs,
// allow_missing_files: allow_missing_files, name: name, ctx: ctx);
//}
}
var dict = new Dictionary<string, object>(); var dict = new Dictionary<string, object>();
dict["checkpoint_prefixes"] = checkpoint_prefixes; dict["checkpoint_prefixes"] = checkpoint_prefixes;
dict["destination_prefix"] = destination_prefix; dict["destination_prefix"] = destination_prefix;
if (delete_old_dirs.HasValue)
dict["delete_old_dirs"] = delete_old_dirs.Value;
dict["delete_old_dirs"] = delete_old_dirs;
var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict);
return op; return op;
} }


//public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx)
//{
// checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING);
// destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING);
// var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix };
// var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files };
// var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name);
// result = null;
// return null;
//}

/// <summary> /// <summary>
/// Transforms a spectrogram into a form that's useful for speech recognition. /// Transforms a spectrogram into a form that's useful for speech recognition.
/// </summary> /// </summary>
@@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations
/// </remarks> /// </remarks>
public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch")
{ {
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern));
return result[0];
}
var dict = new Dictionary<string, object>(); var dict = new Dictionary<string, object>();
dict["input"] = input; dict["input"] = input;
dict["pattern"] = pattern; dict["pattern"] = pattern;
@@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations
/// </remarks> /// </remarks>
public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename")
{ {
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards));
return result[0];
}
var dict = new Dictionary<string, object>(); var dict = new Dictionary<string, object>();
dict["basename"] = basename; dict["basename"] = basename;
dict["shard"] = shard; dict["shard"] = shard;
@@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations
/// </remarks> /// </remarks>
public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin")
{ {
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator));
return result[0];
}
var dict = new Dictionary<string, object>(); var dict = new Dictionary<string, object>();
dict["inputs"] = inputs; dict["inputs"] = inputs;
if (separator != null) if (separator != null)


+ 32
- 0
src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -14,7 +14,9 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Linq;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -23,11 +25,41 @@ namespace Tensorflow
{ {
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
{ {
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var result = tf.Runner.TFE_FastPathExecute(
new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors }));
result = null;
return null;
}
catch (System.Exception)
{
return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx);
}
}
var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });


return _op; return _op;
} }


/// <summary>
/// Eager fallback for the SaveV2 op: converts the inputs to eager tensors and
/// executes "SaveV2" directly via execute.quick_execute.
/// </summary>
/// <param name="prefix">Checkpoint path prefix (converted to a string tensor).</param>
/// <param name="tensor_names">Names of the tensors to save.</param>
/// <param name="shape_and_slices">Shape/slice specs, one per tensor.</param>
/// <param name="tensors">The tensor values to write.</param>
/// <param name="name">Op name.</param>
/// <param name="ctx">Eager context to execute in.</param>
/// <returns>Always null — SaveV2 has no outputs (num_outputs is 0 below).</returns>
public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx)
{
    DataType[] attr_dtypes;
    // NOTE(review): `onvert_to_mixed_eager_tensors` looks like a typo of
    // `convert_to_mixed_eager_tensors` — confirm against Eager/execute.cs.
    (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx);
    prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
    var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
    var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);
    // NOTE(review): SaveV2 documents its inputs as (prefix, tensor_names,
    // shape_and_slices, tensors...); here the value tensors come first —
    // verify this ordering is intentional.
    var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray();
    var attrs = new object[] { "dtypes", attr_dtypes };

    var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name);
    result = null;
    return null;
}

public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{ {
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });


+ 60
- 0
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -17,6 +17,9 @@
using System; using System;
using System.Linq; using System.Linq;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types; using static Tensorflow.CppShapeInferenceResult.Types;


namespace Tensorflow namespace Tensorflow
@@ -38,6 +41,11 @@ namespace Tensorflow
{ {
return var is ResourceVariable; return var is ResourceVariable;
} }
/// <summary>
/// Whether the given trackable is backed by a resource variable, i.e. derives
/// from <see cref="BaseResourceVariable"/>.
/// </summary>
public static bool is_resource_variable(Trackable var)
    => var is BaseResourceVariable;


/// <summary> /// <summary>
/// Creates a variable handle with information to do shape inference. /// Creates a variable handle with information to do shape inference.
@@ -171,5 +179,57 @@ namespace Tensorflow
return HandleData.Parser.ParseFrom(handle.BufferToArray()); return HandleData.Parser.ParseFrom(handle.BufferToArray());
} }
} }

/// <summary>
/// Copies an existing variable to a new graph, with no initializer.
/// </summary>
/// <param name="variable">Source variable whose metadata is mirrored.</param>
/// <returns>
/// An uninitialized variable carrying the source's trainability, shape,
/// dtype, shared name and aggregation, with its trackable state set up.
/// </returns>
public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable)
{
    var copy = new UninitializedVariable(
        trainable: variable.Trainable,
        shape: variable.shape,
        dtype: variable.dtype,
        name: variable.SharedName,
        aggregation: variable.Aggregation,
        extra_handle_data: null);
    // Ensure the trackable machinery of the fresh copy is initialized
    // before handing it back to the caller.
    copy._maybe_initialize_trackable();
    return copy;
}

/// <summary>
/// Writes additional information of the variable into the SavedObject proto.
/// </summary>
/// <param name="resource_variable">The variable whose metadata is recorded.</param>
/// <param name="proto">The SavedObject proto receiving the variable description.</param>
/// <param name="options">Save options; controls whether device placement is recorded.</param>
/// <param name="enforcing_naming">When true, the variable name must end with ":0" or a ValueError is thrown.</param>
public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true)
{
    // lack of API: `proto.Variable.SetInParent()`.
    // Names without the ":0" suffix would not round-trip through restore,
    // so refuse to save them when naming is enforced.
    if(enforcing_naming && !resource_variable.Name.EndsWith(":0"))
    {
        throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " +
            $"unexpected suffix in the name (expected ':0') which won't be restored.");
    }
    if(proto.Variable is null)
    {
        proto.Variable = new SavedVariable();
    }
    // presumably meta_graph.op_name strips the output suffix from the tensor
    // name, leaving the producing op's name — confirm against meta_graph.cs.
    proto.Variable.Name = meta_graph.op_name(resource_variable.Name);
    proto.Variable.Trainable = resource_variable.Trainable;
    proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum();
    // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`.
    proto.Variable.Aggregation = resource_variable.Aggregation;
    proto.Variable.Shape = resource_variable.shape.as_proto();

    // Device placement is only recorded when the save policy asks for it
    // and the variable actually has a device assigned.
    if (options.experimental_variable_policy.save_variable_devices())
    {
        if (!string.IsNullOrEmpty(resource_variable.Device))
        {
            proto.Variable.Device = resource_variable.Device;
        }
    }
}
} }
} }

+ 9
- 1
src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs View File

@@ -156,7 +156,7 @@ namespace Tensorflow {
/// Nodes[0] is considered the root node. /// Nodes[0] is considered the root node.
/// </summary> /// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes {
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes {
get { return nodes_; } get { return nodes_; }
} }


@@ -286,6 +286,7 @@ namespace Tensorflow {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject(SavedObject other) : this() { public SavedObject(SavedObject other) : this() {
children_ = other.children_.Clone(); children_ = other.children_.Clone();
dependencies_ = other.dependencies_.Clone();
slotVariables_ = other.slotVariables_.Clone(); slotVariables_ = other.slotVariables_.Clone();
saveableObjects_ = other.saveableObjects_.Clone(); saveableObjects_ = other.saveableObjects_.Clone();
switch (other.KindCase) { switch (other.KindCase) {
@@ -328,6 +329,7 @@ namespace Tensorflow {
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
/// <summary> /// <summary>
/// Objects which this object depends on: named edges in the dependency /// Objects which this object depends on: named edges in the dependency
/// graph. /// graph.
@@ -338,6 +340,11 @@ namespace Tensorflow {
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children {
get { return children_; } get { return children_; }
} }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies {
get { return dependencies_; }
}


/// <summary>Field number for the "slot_variables" field.</summary> /// <summary>Field number for the "slot_variables" field.</summary>
public const int SlotVariablesFieldNumber = 3; public const int SlotVariablesFieldNumber = 3;
@@ -617,6 +624,7 @@ namespace Tensorflow {
return; return;
} }
children_.Add(other.children_); children_.Add(other.children_);
dependencies_.Add(other.dependencies_);
slotVariables_.Add(other.slotVariables_); slotVariables_.Add(other.slotVariables_);
saveableObjects_.Add(other.saveableObjects_); saveableObjects_.Add(other.saveableObjects_);
switch (other.KindCase) { switch (other.KindCase) {


+ 16
- 0
src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs View File

@@ -198,6 +198,22 @@ namespace Tensorflow {
public TrackableObject() { public TrackableObject() {
OnConstruction(); OnConstruction();
} }
// NOTE(review): these convenience constructors are hand-added to protobuf-generated
// code; regenerating TrackableObjectGraph.cs will discard them — consider moving
// them to a separate partial-class file.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
// Creates a TrackableObject that adopts (does not copy) the given slot-variable list.
public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) {
OnConstruction();
slotVariables_ = slot;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
// Creates a TrackableObject that adopts (does not copy) both the slot-variable
// list and the children list.
public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot,
pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children
)
{
OnConstruction();
slotVariables_ = slot;
children_ = children;
}


partial void OnConstruction(); partial void OnConstruction();




+ 1
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io</Description>


<ItemGroup> <ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="Protobuf.Text" Version="0.6.0" /> <PackageReference Include="Protobuf.Text" Version="0.6.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup> </ItemGroup>


+ 18
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -202,6 +202,24 @@ namespace Tensorflow
_ => type.ToString() _ => type.ToString()
}; };


/// <summary>
/// Maps a <see cref="TF_DataType"/> to the dtype name used on the Python side of
/// TensorFlow (e.g. inside saved-model argument specs). Values without an explicit
/// mapping fall back to the enum's <c>ToString()</c>, preserving prior behavior.
/// </summary>
public static string as_python_name(this TF_DataType type)
    => type switch
    {
        TF_DataType.TF_STRING => "str",
        TF_DataType.TF_UINT8 => "uint8",
        TF_DataType.TF_INT8 => "int8",
        // Added: 16-bit and half-precision types whose Python names differ
        // from the enum member names.
        TF_DataType.TF_UINT16 => "uint16",
        TF_DataType.TF_INT16 => "int16",
        TF_DataType.TF_UINT32 => "uint32",
        TF_DataType.TF_INT32 => "int32",
        TF_DataType.TF_UINT64 => "uint64",
        TF_DataType.TF_INT64 => "int64",
        TF_DataType.TF_HALF => "float16",
        TF_DataType.TF_FLOAT => "float32",
        TF_DataType.TF_DOUBLE => "float64",
        TF_DataType.TF_BOOL => "bool",
        TF_DataType.TF_RESOURCE => "resource",
        TF_DataType.TF_VARIANT => "variant",
        _ => type.ToString()
    };

public static int get_datatype_size(this TF_DataType type) public static int get_datatype_size(this TF_DataType type)
=> type.as_base_dtype() switch => type.as_base_dtype() switch
{ {


+ 67
- 2
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -1,6 +1,71 @@
namespace Tensorflow.Train
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{ {
public abstract class AutoTrackable : Trackable
public class AutoTrackable : Trackable
{ {
/// <summary>
/// Removes the checkpoint dependency registered under <paramref name="name"/>, if one exists.
/// Both the name lookup and every matching entry in the dependency list are removed.
/// </summary>
/// <param name="name">The dependency name to remove.</param>
public void _delete_tracking(string name)
{
    _maybe_initialize_trackable();
    if (!_unconditional_dependency_names.ContainsKey(name))
    {
        return;
    }
    _unconditional_dependency_names.Remove(name);
    // Walk backwards so removal does not disturb the indices still to visit.
    for (int idx = _unconditional_checkpoint_dependencies.Count - 1; idx >= 0; idx--)
    {
        if (_unconditional_checkpoint_dependencies[idx].Name == name)
        {
            _unconditional_checkpoint_dependencies.RemoveAt(idx);
        }
    }
}

/// <summary>
/// Returns the trackable children of this object. For SavedModel export, function-typed
/// properties (<see cref="Function"/> / <see cref="ConcreteFunction"/>) are collected via
/// reflection and merged with the ordinary checkpoint dependencies; for other save types
/// the base implementation is used unchanged.
/// </summary>
/// <exception cref="ValueError">
/// Thrown when a dependency and a function property share a name but refer to different objects.
/// </exception>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
    if (save_type != SaveType.SAVEDMODEL)
    {
        return base._trackable_children(save_type, cache);
    }

    Dictionary<string, Trackable> functions = new();
    // TODO: process of logs.
    var properties = this.GetType().GetProperties();
    foreach (var property in properties)
    {
        if (property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction))
        {
            string name = property.Name;
            object value = property.GetValue(this, null);
            functions[name] = (Trackable)value;
        }
    }

    // TODO: process the type `core_types.GenericFunction`.

    Dictionary<string, Trackable> children = new();
    foreach (var pair in CheckpointDependencies)
    {
        var name = pair.Name;
        var child = pair.Refer;
        if (child is ConcreteFunction) // or Generic function
        {
            continue;
        }
        if (functions.ContainsKey(name))
        {
            if (functions[name] != child)
            {
                throw new ValueError($"Can't save object because it has multiple children with the same " +
                    $"name. Object: {this}, attribute name: {name}, child 1: " +
                    $"{child}, child 2: {functions[name]}");
            }
            // Same object is reachable both as a dependency and as a function property.
            // Skip it here; the merge below keeps the function entry. (The previous
            // Concat().ToDictionary() threw ArgumentException on this duplicate key.)
            continue;
        }
        children[name] = child;
    }

    // Merge with function properties taking precedence; the indexer assignment
    // tolerates key overlap, unlike Enumerable.ToDictionary.
    foreach (var pair in functions)
    {
        children[pair.Key] = pair.Value;
    }
    return children;
}
} }
} }

+ 12
- 0
src/TensorFlowNET.Core/Training/IWithTrackable.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Training
{
    /// <summary>
    /// Implemented by types that expose an underlying <see cref="Trackable"/> object,
    /// letting wrappers that are not themselves Trackable participate in
    /// checkpointing and saved-model serialization.
    /// </summary>
    public interface IWithTrackable
    {
        // Returns the wrapped Trackable instance.
        Trackable GetTrackable();
    }
}

+ 9
- 0
src/TensorFlowNET.Core/Training/LayerUtils.cs View File

@@ -0,0 +1,9 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Training
{

}

+ 6
- 1
src/TensorFlowNET.Core/Training/Optimizer.cs View File

@@ -351,7 +351,7 @@ namespace Tensorflow
/// <param name="var"></param> /// <param name="var"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
protected IVariableV1 get_slot(IVariableV1 var, string name)
internal IVariableV1 get_slot(IVariableV1 var, string name)
{ {
var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; var named_slots = _slots.ContainsKey(name) ? _slots[name] : null;
if (named_slots == null) if (named_slots == null)
@@ -360,6 +360,11 @@ namespace Tensorflow
return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null;
} }


/// <summary>
/// Returns the names of all slot collections held by this optimizer (e.g. momentum buffers).
/// </summary>
internal IEnumerable<string> get_slot_names() => _slots.Keys;

private string _var_key(IVariableV1 var) private string _var_key(IVariableV1 var)
{ {
return $"{var.Op.graph.graph_key}.{var.Op.name}"; return $"{var.Op.graph.graph_key}.{var.Op.name}";


+ 28
- 0
src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using static Tensorflow.Binding;

namespace Tensorflow namespace Tensorflow
{ {
public class ResourceVariableSaveable : MySaveableObject public class ResourceVariableSaveable : MySaveableObject
@@ -35,6 +37,32 @@ namespace Tensorflow
this.name = name; this.name = name;
} }


/// <summary>
/// Builds a saveable wrapping <paramref name="var"/>: captures a read of the
/// variable's current value as a single <see cref="SaveSpec"/> for checkpoint writing.
/// </summary>
/// <param name="var">The resource variable to save.</param>
/// <param name="slice_spec">Slice specification (empty string for a full variable).</param>
/// <param name="name">Checkpoint key under which the value is saved.</param>
public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name)
{
    _var_device = var.Device;
    _var_shape = var.shape;

    // Reads the variable's value, copying it to CPU for serialization.
    // NOTE(review): tf.device(...) is called without entering/disposing the returned
    // context, so it is unclear from this file whether the placement actually
    // applies — confirm against tf.device's semantics in this codebase.
    Tensor _read_variable_closure(BaseResourceVariable v)
    {
        tf.device(v.Device);
        // In eager mode an uninitialized variable has no value yet; return null
        // so the SaveSpec carries no tensor.
        if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
        {
            return null;
        }
        var x = v.read_value_no_copy();
        tf.device("/device:CPU:0");
        return array_ops.identity(x);
    }

    this.handle_op = var.Handle;
    var tensor = _read_variable_closure(var);

    var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype);
    _op = var;
    specs = new SaveSpec[] { spec };
    this.name = name;
}

public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{ {
var restored_tensor = restored_tensors[0]; var restored_tensor = restored_tensors[0];


+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
public string slice_spec => _slice_spec; public string slice_spec => _slice_spec;


private string _name; private string _name;
public string name => _name;
public string name { get => _name; set => _name = value; }


private TF_DataType _dtype; private TF_DataType _dtype;
public TF_DataType dtype => _dtype; public TF_DataType dtype => _dtype;


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save