Browse Source

Support loading weights for customized layer.

tags/v0.100.5-BERT-load
AsakusaRinne 2 years ago
parent
commit
3943375b67
30 changed files with 942 additions and 198 deletions
  1. +13
    -12
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  2. +6
    -4
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  3. +15
    -16
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Checkpoint/restore.cs
  5. +5
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  6. +5
    -2
      src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  8. +14
    -1
      src/TensorFlowNET.Core/Operations/handle_data_util.cs
  9. +65
    -31
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  10. +6
    -2
      src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs
  11. +19
    -0
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  12. +14
    -10
      src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs
  13. +44
    -3
      src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs
  14. +29
    -4
      src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
  15. +17
    -10
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  16. +48
    -35
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  17. +67
    -1
      src/TensorFlowNET.Core/Training/Trackable.cs
  18. +344
    -27
      src/TensorFlowNET.Core/Training/data_structures.cs
  19. +8
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  21. +8
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  22. +57
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs
  23. +51
    -2
      src/TensorFlowNET.Keras/Engine/Model.cs
  24. +34
    -17
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  25. +1
    -13
      src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs
  26. +10
    -5
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  27. +8
    -0
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
  28. +22
    -0
      src/TensorFlowNET.Keras/Utils/compile_utils.cs
  29. +25
    -0
      src/TensorFlowNET.Keras/Utils/tf_utils.cs
  30. +3
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 13
- 12
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow.Checkpoint
);
public static class SaveUtil
{
public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -119,16 +119,16 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param>
/// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new();
foreach(var td in tensor_trackables)
{
// TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
Trackable trackable = null;
IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict;
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -150,12 +150,12 @@ namespace Tensorflow.Checkpoint
return serialized_tensors;
}

private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
var trackable = trackable_data.object_to_save;

// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict;
if (call_with_mapped_captures)
{
throw new NotImplementedException();
@@ -165,8 +165,7 @@ namespace Tensorflow.Checkpoint
ret_tensor_dict = trackable.serialize_to_tensors();
}

// TODO: deal with the type `SaveSpce` (currently it will never be it).
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict)
{
var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -175,10 +174,12 @@ namespace Tensorflow.Checkpoint

tensor_dict[checkpoint_key] = maybe_tensor;

if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
foreach(var key in maybe_tensor.Keys)
{
throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>())
{
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name;
}
}

if(object_graph_proto is not null)
@@ -202,7 +203,7 @@ namespace Tensorflow.Checkpoint
/// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param>
/// <returns></returns>
private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, string> object_names = new();


+ 6
- 4
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -45,12 +45,12 @@ public class TrackableSaver
_graph_view = graph_view;
// TODO: cache when not executing eagerly.
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
}

private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);
@@ -69,9 +69,10 @@ public class TrackableSaver
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
if (!serialized_tensors.ContainsKey(Trackable.None))
{
serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>();
serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>();
}
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>();
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor);
return (serialized_tensors, feed_additions, registered_savers, graph_proto);
}

@@ -387,6 +388,7 @@ public class CheckpointRestoreCoordinator
/// </summary>
public List<Trackable> AllTrackables => _all_trackables;
public HashSet<int> MatchedProtoIds => _matched_proto_ids;
// TODO(Rinne): change to weak ref.
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
public int RestoreUid => _restore_uid;
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;


+ 15
- 16
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -160,12 +160,12 @@ namespace Tensorflow.Checkpoint
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param>
/// <param name="registered_savers"></param>
/// <param name="call_with_mapped_capture"></param>
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors,
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors,
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false)
{
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>();
_restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>();
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new();
Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new();
foreach(var pair in serialized_tensors)
{
@@ -191,16 +191,7 @@ namespace Tensorflow.Checkpoint
foreach(var item in tensor_dict)
{
var checkpoint_key = item.Key;
IDictionary<string, Tensor> spec_to_tensor;
if(item.Value.TryPickT0(out var t, out var dic))
{
spec_to_tensor = new Dictionary<string, Tensor>();
spec_to_tensor[""] = t;
}
else
{
spec_to_tensor = dic;
}
var spec_to_tensor = item.Value;

foreach(var spec in spec_to_tensor)
{
@@ -216,11 +207,19 @@ namespace Tensorflow.Checkpoint
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

// skip the process of device name because lack of API.
var host_device = tensor.Device;
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>());
string host_device;
if (tensor.IsT0)
{
host_device = tensor.AsT0.Device;
}
else
{
host_device = tensor.AsT1.device;
}
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>());
if (!internal_dict.ContainsKey(checkpoint_key))
{
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>();
internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>();
}
internal_dict[checkpoint_key][slice_spec] = tensor;
}
@@ -425,7 +424,7 @@ namespace Tensorflow.Checkpoint

public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
{
Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new();
foreach (var saveable in saveables)
{
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable });


+ 2
- 1
src/TensorFlowNET.Core/Checkpoint/restore.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
@@ -50,7 +51,7 @@ public class CheckpointPosition
{
_checkpoint.AllTrackables.Add(trackable);
_checkpoint.MatchedProtoIds.Add(_proto_id);
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment))
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null)
{
// skip the `logging.warning`.
return false;


+ 5
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -120,6 +120,11 @@ namespace Tensorflow.Contexts
name :
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";

public string anonymous_name()
{
return "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
}

public void graph_mode(bool isFunc = false)
=> context_switches.Push(false, isFunc);



+ 5
- 2
src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs View File

@@ -6,8 +6,11 @@
public class DenseSpec : TypeSpec
{
protected Shape _shape;
public Shape shape => _shape;

public Shape shape
{
get { return _shape; }
set { _shape = value; }
}
protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;



+ 1
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -311,7 +311,7 @@ namespace Tensorflow
/// <param name="types">const TF_DataType*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output,
public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output,
int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types,
SafeStatusHandle status);



+ 14
- 1
src/TensorFlowNET.Core/Operations/handle_data_util.cs View File

@@ -30,6 +30,18 @@ namespace Tensorflow.Operations
}
}

public static HandleData create_handle_data(Shape shape, TF_DataType dtype)
{
HandleData handle_data = new();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new HandleShapeAndType()
{
Shape = shape.as_proto(),
Dtype = dtype.as_datatype_enum()
});
return handle_data;
}

public static void set_handle_data(Tensor target_t, HandleData handle_data)
{
if(target_t is EagerTensor)
@@ -37,7 +49,8 @@ namespace Tensorflow.Operations
target_t.HandleData = handle_data;
return;
}
c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
// TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
//c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
}
}
}

+ 65
- 31
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -21,6 +21,9 @@ using Tensorflow.Train;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types;
using static Tensorflow.Binding;
using Tensorflow.Operations;
using System.Buffers;

namespace Tensorflow
{
@@ -31,6 +34,7 @@ namespace Tensorflow
{
public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null)
{
// TODO(Rinne): deal with `_handle_graph`.
var value_tensor = ops.convert_to_tensor(value);
return gen_resource_variable_ops.assign_variable_op(handle,
value_tensor,
@@ -78,6 +82,18 @@ namespace Tensorflow
string shared_name, string name, bool graph_mode, Tensor initial_value = null)
{
var container = ops.get_default_graph().Container;
if(container is null)
{
container = "";
}
if (!graph_mode)
{
if(shared_name is not null)
{
throw new Exception("Using an explicit shared_name is not allowed when executing eagerly.");
}
shared_name = tf.Context.anonymous_name();
}
var handle = gen_resource_variable_ops.var_handle_op(shape: shape,
dtype: dtype,
shared_name: shared_name,
@@ -95,26 +111,20 @@ namespace Tensorflow
}
else
{
// We do not want two distinct ResourceVariable objects for the same
// underlying resource in the runtime.
// When in eager mode, explicitly ensure so here. When in graph mode, it's
// ensured by always generating different variable names.
var exists = gen_resource_variable_ops.var_is_initialized_op(handle);

// We create an assert Op instead of checking right away in order to be
// compatible with ASYNC execution mode. Further, since not all devices
// support string tensors, we encode the assertion string in the Op name
/*gen_logging_ops.assert(gen_math_ops.logical_not(exists),
new[] { exists },
name: "EagerVariableNameReuse");*/

var handle_data = new HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new HandleShapeAndType
var handle_data = handle_data_util.create_handle_data(shape, dtype);
if (initial_value is not null && initial_value.dtype == dtypes.variant)
{
Dtype = dtype.as_datatype_enum(),
Shape = shape.as_proto()
});
var extra_handle_data = get_eager_safe_handle_data(initial_value);
if (extra_handle_data is not null && extra_handle_data.IsSet)
{
if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1)
{
throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " +
$"but saw: '{handle_data}'");
}
handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType);
}
}
_set_handle_shapes_and_types(handle, handle_data, graph_mode);
return handle;
}
@@ -126,24 +136,48 @@ namespace Tensorflow
/// <param name="handle"></param>
/// <param name="handle_data"></param>
/// <param name="graph_mode"></param>
internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
{
tensor.HandleData = handle_data;
if (!graph_mode)
return;

var size = handle_data.ShapeAndType.Count;
//var shapes = handle_data.ShapeAndType.Select(x => x.Shape);
//var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray();
//var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray();
//var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s =>
//{
// if (!s.UnknownRank)
// {
// return s.Dim.Select(d => (int)d.Size).ToArray();
// }
// else
// {
// return Memory<int>.Empty;
// }
//}).ToArray();

var shapes = new IntPtr[size];
var types = new DataType[size];
var ranks = new int[size];
//List<MemoryHandle> handles = new();
//IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length];
//foreach(var (i, m) in enumerate(converted_shapes))
//{
// if(m.IsEmpty)
// {
// shapes_with_ptr[i] = IntPtr.Zero;
// }
// else
// {
// var handle = m.Pin();
// handles.Add(handle);
// shapes_with_ptr[i] = new IntPtr(handle.Pointer);
// }
//}

for (int i = 0; i < size; i++)
{
var shapeAndType = handle_data.ShapeAndType[i];
types[i] = shapeAndType.Dtype;
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count;
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray();
}
//Status status = new();
//// TODO(Rinne): enable it.
//c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(),
// shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status);
//handles = null;
}

/// <summary>


+ 6
- 2
src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs View File

@@ -330,7 +330,7 @@ namespace Tensorflow {
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
= pb::FieldCodec.ForMessage(122, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
/// <summary>
@@ -698,9 +698,13 @@ namespace Tensorflow {
break;
case 10: {
children_.AddEntriesFrom(input, _repeated_children_codec);
dependencies_.AddRange(children_.Except(dependencies_));
break;
}
case 122:
{
dependencies_.AddEntriesFrom(input, _repeated_dependencies_codec);
break;
}
case 26: {
slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec);
break;


+ 19
- 0
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -3,6 +3,7 @@ using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Operations.Activation;
using Tensorflow.Training;
using static Tensorflow.Binding;

namespace Tensorflow.Train
@@ -25,6 +26,13 @@ namespace Tensorflow.Train
}
}

public override void SetAttr(string name, object value)
{
// TODO(Rinne): deal with `self_setattr_tracking`.
value = TrackableDataStructure.sticky_attribute_assignment(this, name, value);
base.SetAttr(name, value);
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
if(save_type != SaveType.SAVEDMODEL)
@@ -34,6 +42,7 @@ namespace Tensorflow.Train

Dictionary<string, Trackable> functions = new();
// TODO: process of logs.
// TODO(Rinne): deal with members.
var properties = this.GetType().GetProperties();
foreach ( var property in properties )
{
@@ -45,6 +54,16 @@ namespace Tensorflow.Train
}
}

foreach(var item in CustomizedFields)
{
var name = item.Key;
var value = item.Value;
if (value is Function or ConcreteFunction)
{
functions[name] = (Trackable)value;
}
}

// TODO: process the type `core_types.GenericFunction`.

Dictionary<string, Trackable> children = new();


+ 14
- 10
src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs View File

@@ -42,22 +42,25 @@ namespace Tensorflow
_var_device = var.Device;
_var_shape = var.shape;

Tensor _read_variable_closure(BaseResourceVariable v)
Func<Tensor> _read_variable_closure(BaseResourceVariable v)
{
tf.device(v.Device);
if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
return () =>
{
return null;
}
var x = v.read_value_no_copy();
tf.device("/device:CPU:0");
return array_ops.identity(x);
tf.device(v.Device);
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
{
return null;
}
var x = v.read_value_no_copy();
tf.device("/device:CPU:0");
return array_ops.identity(x);
};
}

this.handle_op = var.Handle;
var tensor = _read_variable_closure(var);
var tensor_creator = _read_variable_closure(var);

var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype);
var spec = new SaveSpec(tensor_creator, slice_spec, name, dtype: var.dtype, device: var.Device);
_op = var;
specs = new SaveSpec[] { spec };
this.name = name;
@@ -66,6 +69,7 @@ namespace Tensorflow
public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];
tf.device(_var_device);
restored_tensor = array_ops.identity(restored_tensor);
return resource_variable_ops.shape_safe_assign_variable_handle(
handle_op, _var_shape, restored_tensor);


+ 44
- 3
src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Exceptions;

namespace Tensorflow
{
/// <summary>
@@ -21,8 +23,24 @@ namespace Tensorflow
/// </summary>
public class SaveSpec
{
private Tensor _tensor;
public Tensor tensor => _tensor;
private Tensor _tensor = null;
private Func<Tensor> _tensor_creator = null;
public Tensor tensor
{
get
{
if(_tensor is not null || _tensor_creator is null)
{
return _tensor;
}
else
{
return _tensor_creator();
}
}
}

internal Func<Tensor> TensorCreator => _tensor_creator;

private string _slice_spec;
public string slice_spec => _slice_spec;
@@ -32,13 +50,36 @@ namespace Tensorflow

private TF_DataType _dtype;
public TF_DataType dtype => _dtype;
private string _device;
public string device => _device;

public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid)
public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null)
{
_tensor = tensor;
_slice_spec = slice_spec;
_name = name;
_dtype = dtype;
if(device is not null)
{
_device = device;
}
else
{
_device = tensor.Device;
}
}

public SaveSpec(Func<Tensor> tensor_creator, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null)
{
_tensor_creator = tensor_creator;
_slice_spec = slice_spec;
_name = name;
if(dtype == TF_DataType.DtInvalid || device is null)
{
throw new AssertionError("When passing a callable `tensor` to a SaveSpec, an explicit dtype and device must be provided.");
}
_dtype = dtype;
_device = device;
}
}
}

+ 29
- 4
src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs View File

@@ -1,10 +1,20 @@
using System;
using System.Diagnostics;
using Tensorflow.Train;
using Tensorflow.Training;

namespace Tensorflow;

public class RevivedTypes
{
private static Dictionary<string, ITrackableWrapper> _registered_revived_creator = new();
static RevivedTypes()
{
var list_wrapper = new ListWrapper(new Trackable[] { });
_registered_revived_creator[list_wrapper.Identifier] = list_wrapper;
var dict_wrapper = new DictWrapper(new Dictionary<object, Trackable>());
_registered_revived_creator[dict_wrapper.Identifier] = dict_wrapper;
}
/// <summary>
/// Create a SavedUserObject from a trackable object.
/// </summary>
@@ -12,13 +22,28 @@ public class RevivedTypes
/// <returns></returns>
public static SavedUserObject? serialize(Trackable obj)
{
// TODO: complete the implementation.
// TODO(Rinne): complete the implementation.
return null;
}

public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto)
public static (Trackable, Action<object, object, object>) deserialize(SavedUserObject proto)
{
// TODO: complete the implementation.
return null;
if(_registered_revived_creator.TryGetValue(proto.Identifier, out var wrapper))
{
return (wrapper.FromProto(proto), (x, y, z) =>
{
if (x is not ITrackableWrapper trackable)
{
throw new TypeError($"The type is expected to be `ITrackableWrapper`, but got {x.GetType()}.");
}
Debug.Assert(y is string);
trackable.SetValue(y, z);
}
);
}
else
{
return (null, null);
}
}
}

+ 17
- 10
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -49,6 +49,7 @@ namespace Tensorflow
var temp = _proto.ToString();
_export_dir = export_dir;
// TODO: `this._concrete_functions` and `this._restored_concrete_functions`
// TODO(Rinne): This method is very slow, needs to be accelareted.
_concrete_functions = function_deserialization.load_function_def_library(
meta_graph.GraphDef.Library, _proto);
_restored_concrete_functions = new HashSet<string>();
@@ -523,7 +524,7 @@ namespace Tensorflow
continue;
}
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
// skip the process of "__call__"
// TODO(Rinne): deal with "__call__"
}
}

@@ -595,13 +596,12 @@ namespace Tensorflow
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id)
{
// skip the check of proto identifier because of lack of property.

var looked_up = RevivedTypes.deserialize(proto);
if(looked_up is null)
var (trackable, setter) = RevivedTypes.deserialize(proto);
if(trackable is null)
{
return _recreate_base_user_object(proto, node_id);
}
return (looked_up.Item1, looked_up.Item2);
return (trackable, setter);
}

private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null)
@@ -668,13 +668,20 @@ namespace Tensorflow
public static Action<object, object, object> setattr = (x, y, z) =>
{
Debug.Assert(y is string);
var properties = x.GetType().GetProperties();
foreach(var p in properties)
if(x is Trackable trackable)
{
trackable.SetAttr(y as string, z);
}
else
{
if((string)y == p.Name)
var properties = x.GetType().GetProperties();
foreach (var p in properties)
{
p.SetValue(x, z);
return;
if ((string)y == p.Name)
{
p.SetValue(x, z);
return;
}
}
}
// TODO(Rinne): check if the property has been set successfully.


+ 48
- 35
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -50,6 +50,10 @@ namespace Tensorflow
}
public static class saveable_object_util
{
public static string NO_SLICE_SPEC_KEY = "";
private static HashSet<string> _VARIABLE_OPS = new HashSet<string>(new string[] {
"Variable", "VariableV2", "AutoReloadVariable", "VarHandleOp", "ReadVariableOp"
});
/// <summary>
/// Returns the variables and names that will be used for a Saver.
/// </summary>
@@ -123,19 +127,12 @@ namespace Tensorflow
/// <returns></returns>
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Tensor op, string name)
{
if (false)
{
}
ops.init_scope();
var variable = ops.convert_to_tensor(op, as_ref: true);
if (variable.dtype.is_ref_dtype())
yield return new ReferenceVariableSaveable(variable, "", name);
else
{
ops.init_scope();
var variable = ops.convert_to_tensor(op, as_ref: true);
if (variable.dtype.is_ref_dtype())
yield return new ReferenceVariableSaveable(variable, "", name);
else
yield return new ResourceVariableSaveable(variable, "", name);
}
yield return new ResourceVariableSaveable(variable, "", name);
}

/// <summary>
@@ -159,7 +156,7 @@ namespace Tensorflow
yield return new ResourceVariableSaveable(variable, "", name);
}
}
else
else if(obj is not IVariableV1)
{
foreach(var pair in saveable_objects_from_trackable(obj))
{
@@ -191,6 +188,30 @@ namespace Tensorflow
}
}
}
else
{
// Variable
if (tf.Context.executing_eagerly())
{
throw new ValueError($"Can only save/restore ResourceVariables when " +
$"executing eagerly, got type: {obj.GetType()}.");
}
var variable = ops.convert_to_tensor(obj, as_ref: true);
if (!_tensor_comes_from_variable(variable))
{
throw new TypeError($"names_to_saveables must be a dict mapping string " +
$"names to Tensors/Variables. Not a variable: {variable}");
}
if(variable.op.type == "Variable" || variable.op.type == "VariableV2" ||
variable.op.type == "AutoReloadVariable")
{
yield return new ReferenceVariableSaveable(variable, "", name);
}
else
{
yield return new ResourceVariableSaveable(variable, "", name);
}
}
}

/// <summary>
@@ -267,24 +288,14 @@ namespace Tensorflow
foreach (var pair in tensor_dict)
{
var tensor_name = pair.Key;
var maybe_tensor = pair.Value;
var internal_dict = pair.Value;
local_names.Add(tensor_name);
string spec_name = name + TrackableUtils.escape_local_name(tensor_name);

IDictionary<string, Tensor> internal_dict;
if (maybe_tensor.TryPickT0(out var tensor, out var dic))
{
internal_dict = new Dictionary<string, Tensor>();
internal_dict[""] = tensor;
}
else
{
internal_dict = dic;
}

foreach (var item in internal_dict)
{
specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
Debug.Assert(item.Value.IsT0);
specs.Add(new SaveSpec(item.Value.AsT0, item.Key, spec_name));
}
}
return new TrackableSaveable(obj, specs, name, local_names, prefix);
@@ -316,9 +327,9 @@ namespace Tensorflow
/// Converts a list of SaveableObjects to a tensor dictionary.
/// </summary>
/// <param name="saveables"></param>
public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
public static Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
{
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new();
foreach (var saveable in saveables)
{
foreach (var spec in saveable.specs)
@@ -326,14 +337,11 @@ namespace Tensorflow
// skip the check that if `spec` is callable.
var name = convert_to_string(spec.name);
var slice_spec = convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec))
{
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor;
}
else
if (string.IsNullOrEmpty(slice_spec))
{
tensor_dict[name] = spec.tensor;
slice_spec = NO_SLICE_SPEC_KEY;
}
tensor_dict.SetDefault(name, new Dictionary<string, OneOf<Tensor, SaveSpec>>())[slice_spec] = spec.TensorCreator is null ? spec.tensor : spec;
}
}
return tensor_dict;
@@ -397,6 +405,11 @@ namespace Tensorflow
{
return factory(key);
}

private static bool _tensor_comes_from_variable(object v)
{
return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type);
}
}

public class SaveableCompatibilityConverter: Trackable
@@ -412,7 +425,7 @@ namespace Tensorflow
public object Obj => _obj;
public IList<MySaveableObject> mySaveables=> _saveables;

public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
public override IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors()
{
return saveable_object_util.saveable_object_to_tensor_dict(_saveables);
}


+ 67
- 1
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -85,6 +85,72 @@ namespace Tensorflow.Train
_self_saveable_object_factories = value;
}
}
public Dictionary<string, object> CustomizedFields { get; set; } = new Dictionary<string, object>();

public virtual void SetAttr(string name, object value)
{
var t = this.GetType();
var field_info = t.GetField(name);
if(field_info is not null)
{
field_info.SetValue(this, value);
}
else
{
CustomizedFields[name] = value;
}

// On account of performance, we don't use reflection to set the attribute if it exists in `Trackable`.
// When adding new members or properties to this class, please add corresponding process to this method.
//switch (name)
//{
// case "_manual_tracking":
// {
// _manual_tracking = (bool)value;
// break;
// }
// case "_self_saveable_object_factories":
// {
// _self_saveable_object_factories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value;
// break;
// }
// case "_self_update_uid":
// {
// _self_update_uid = (int)value;
// break;
// }
// case "_unconditional_checkpoint_dependencies":
// {
// _unconditional_checkpoint_dependencies = (IList<TrackableReference>)value;
// break;
// }
// case "_unconditional_deferred_dependencies":
// {
// _unconditional_deferred_dependencies = (Dictionary<string, IList<CheckpointPosition>>)value;
// break;
// }
// case "_unconditional_dependency_names":
// {
// _unconditional_dependency_names = (IDictionary<string, Trackable>)value;
// break;
// }
// case "SelfSaveableObjectFactories":
// {
// SelfSaveableObjectFactories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value;
// break;
// }
// case "UpdateUid":
// {
// UpdateUid = (int)value;
// break;
// }
// default:
// {
// CustomizedAttributes[name] = value;
// break;
// }
// }
}

/// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`.
@@ -279,7 +345,7 @@ namespace Tensorflow.Train
/// </summary>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
public virtual IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors()
{
throw new NotImplementedException();
}


+ 344
- 27
src/TensorFlowNET.Core/Training/data_structures.cs View File

@@ -2,6 +2,8 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO.Compression;
using System.Linq;
using System.Linq.Expressions;
@@ -25,6 +27,48 @@ namespace Tensorflow.Training
}
}

static class TrackableWrapperUtils
{
internal static bool ShouldLoad(ITrackableWrapper wrapper, SavedUserObject proto)
{
if (proto.Identifier != wrapper.Identifier)
{
return false;
}
if (wrapper.Version < proto.Version.MinConsumer)
{
return false;
}
if (proto.Version.Producer < wrapper.MinProducerVersion)
{
return false;
}
foreach (var bad_version in proto.Version.BadConsumers)
{
if (bad_version == wrapper.Version)
{
return false;
}
}
return true;
}

internal static bool is_function(Trackable x)
{
return x is Function or ConcreteFunction;
}
}

public interface ITrackableWrapper
{
void SetValue(object name, object value);
String Identifier { get; }
int Version { get; }
int MinConsumerVersion { get; }
int MinProducerVersion { get; }
Trackable FromProto(SavedUserObject proto);
}

public abstract class TrackableDataStructure : Trackable
{
private bool _self_trainable;
@@ -36,7 +80,7 @@ namespace Tensorflow.Training
_self_extra_variables = new List<IVariableV1>();
}

public abstract IEnumerable<Trackable> Values { get; }
public abstract ICollection<Trackable> Values { get; }
public bool Trainable { get => _self_trainable; set => _self_trainable = value; }
public IEnumerable<ILayer> Layers
{
@@ -134,7 +178,7 @@ namespace Tensorflow.Training
/// <param name="name"></param>
protected virtual Trackable _track_value(Trackable value, string name)
{
value = sticky_attribute_assignment(this, name, value);
value = (Trackable)sticky_attribute_assignment(this, name, value);
if(value is IVariableV1)
{
_self_extra_variables.Add(value as IVariableV1);
@@ -148,44 +192,273 @@ namespace Tensorflow.Training
return value.Value;
}

public static Trackable wrap_or_unwrap(Trackable value)
public static object wrap_or_unwrap(object value)
{
if(value is NoDependency dependency)
{
return dependency.Value;
}
if(value is Trackable trackable)
{
return trackable;
}
else if(value is IDictionary<object, Trackable> obj_dict)
{
return new DictWrapper(obj_dict);
}
else if(value is IList<Trackable> list)
{
return new ListWrapper(list);
}
else
{
return value;
}
}

public static object sticky_attribute_assignment(Trackable trackable, string name, object value)
{
bool add_dependency = value is not NoDependency;
value = wrap_or_unwrap(value);
if (!add_dependency)
{
return value;
}
if(value is Trackable trackable_obj)
{
trackable._track_trackable(trackable_obj, name, true);
}
return value;
}
}
// TODO(Rinne): Add Dict wrapper and Tuple wrapper

public class DictWrapper : TrackableDataStructure, IDictionary<object, Trackable>, ICloneable, ITrackableWrapper
{
private IDictionary<object, Trackable> _storage;
private bool _non_string_key;
private bool _external_modification;
private IDictionary<object, Trackable> _last_wrapped_dict_snapshot;

public DictWrapper(IDictionary<object, Trackable> wrapped_dict = null)
{
if(wrapped_dict is not null)
{
_storage = new Dictionary<object, Trackable>(wrapped_dict);
}
else
{
_storage = new Dictionary<object, Trackable>();
}
_update_snapshot();
}

public static Trackable wrap_or_unwrap(IList<Trackable> value)
public void SetValue(object name, object value)
{
return new ListWrapper(value);
Debug.Assert(value is Trackable);
this[name] = value as Trackable;
}
public String Identifier => "trackable_dict_wrapper";
public int Version => 1;
public int MinConsumerVersion => 1;
public int MinProducerVersion => 1;
public Trackable FromProto(SavedUserObject proto)
{
return new DictWrapper(new Dictionary<object, Trackable>());
}

public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value)
public Trackable this[object key]
{
return new ListWrapper(value.ToList());
get
{
return _storage[key];
}
set
{
_check_self_external_modification();
_maybe_initialize_trackable();
bool no_dep = value is NoDependency;
if(key is string)
{
value = _track_value(value, key);
}
else
{
value = (Trackable)wrap_or_unwrap(value);
if(!no_dep && value is Trackable)
{
_non_string_key = true;
}
}
_storage[key] = value;
_update_snapshot();
}
}

protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value)
public ICollection<object> Keys => _storage.Keys;

public override ICollection<Trackable> Values => _storage.OrderBy(x => x.Key).Select(x => x.Value).ToArray();

public void Add(object key, Trackable value)
{
value = wrap_or_unwrap(value);
trackable._track_trackable(value, name, true);
return value;
_storage[key] = value;
}

protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value)
public bool ContainsKey(object key)
{
var wrapped_value = wrap_or_unwrap(value);
trackable._track_trackable(wrapped_value, name, true);
return wrapped_value;
return _storage.ContainsKey(key);
}

protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value)
public bool Remove(object key)
{
var wrapped_value = wrap_or_unwrap(value);
trackable._track_trackable(wrapped_value, name, true);
return wrapped_value;
_check_self_external_modification();
var res = _storage.Remove(key);
_update_snapshot();
return res;
}
}

public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable
public bool TryGetValue(object key, out Trackable value)
{
return _storage.TryGetValue(key, out value);
}

public int Count => _storage.Count;

public bool IsReadOnly => _storage.IsReadOnly;

public void Add(KeyValuePair<object, Trackable> item)
{
Add(item.Key, item.Value);
}

public void Clear()
{
_storage.Clear();
_update_snapshot();
}

public bool Contains(KeyValuePair<object, Trackable> item)
{
return _storage.Contains(item);
}

public void CopyTo(KeyValuePair<object, Trackable>[] array, int arrayIndex)
{
_storage.CopyTo(array, arrayIndex);
}

public bool Remove(KeyValuePair<object, Trackable> item)
{
_check_self_external_modification();
var res = Remove(item);
_update_snapshot();
return res;
}

public IEnumerator<KeyValuePair<object, Trackable>> GetEnumerator()
{
return _storage.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator();

public object Clone()
{
var copied = new DictWrapper(_storage);
copied._external_modification = _external_modification;
copied._non_string_key = _non_string_key;
return copied;
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
_check_self_external_modification();
if (_non_string_key)
{
throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed \"" +
$"automatically on attribute assignment). The wrapped dictionary " +
$"contains a non-string key which maps to a trackable object or " +
$"mutable data structure.\n\nIf you don't need this dictionary " +
$"checkpointed, wrap it in a non-trackable " +
$"object; it will be subsequently ignored.");
}
if (_external_modification)
{
throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed " +
$"automatically on attribute assignment). The wrapped dictionary was " +
$"modified outside the wrapper (its final value was {this}, its value" +
$" when a checkpoint dependency was added was " +
$"{this._last_wrapped_dict_snapshot}), which breaks " +
$"restoration on object creation.\n\nIf you don't need this " +
$"dictionary checkpointed, wrap it in a " +
$"non-trackable object; it will be subsequently ignored.");
}
Debug.Assert(!Dirty);
var children = base._trackable_children(save_type, cache);

if(save_type == SaveType.SAVEDMODEL)
{
foreach(var item in _storage)
{
var key = item.Key;
var value = item.Value;
if (TrackableWrapperUtils.is_function(value))
{
Debug.Assert(key is string);
children[key as string] = value;
}
}
}

return children;
}

protected Trackable _track_value(Trackable value, object name)
{
bool string_key = name is string;
if (!string_key)
{
name = "-non_string_key";
}
try
{
bool no_dependency = value is NoDependency;
value = base._track_value(value, name as string);
if(!(string_key || no_dependency))
{
_non_string_key = true;
}
return value;
}
catch (ValueError)
{
return (Trackable)sticky_attribute_assignment(this, name as string, value);
}
}

private bool Dirty => _external_modification || _non_string_key;

private void _check_self_external_modification()
{
if (Dirty)
{
return;
}
if(!this._storage.SequenceEqual(_last_wrapped_dict_snapshot))
{
_external_modification = true;
_last_wrapped_dict_snapshot = null;
}
}

private void _update_snapshot()
{
// TODO(Rinne): deal with attribute_sentinel.
if (Dirty) return;
_last_wrapped_dict_snapshot = new Dictionary<object, Trackable>(_storage);
}
}
public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable, ITrackableWrapper
{
private IList<Trackable> _storage;
private bool _non_append_mutation_value;
@@ -198,11 +471,51 @@ namespace Tensorflow.Training
/// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param>
public ListWrapper(IList<Trackable> wrapped_list)
{
_storage = wrapped_list;
_storage = new List<Trackable>(wrapped_list);
_non_append_mutation_value = _external_modification_value = false;
_last_wrapped_list_snapshot = new List<Trackable>(_storage);
}

public string Identifier => "trackable_list_wrapper";
public int Version => 1;
public int MinConsumerVersion => 1;
public int MinProducerVersion => 1;
public Trackable FromProto(SavedUserObject proto)
{
if(TrackableWrapperUtils.ShouldLoad(this, proto))
{
return new ListWrapper(new Trackable[] { });
}
else
{
return null;
}
}
public void SetValue(object name, object value)
{
Debug.Assert(name is string);
if(int.TryParse(name as string, out var index))
{
if(value is not Trackable trackable)
{
throw new TypeError("Cannot set an object which is not trackable to ListWrapper.");
}
if(Count <= index)
{
Add(trackable);
}
else
{
this[index] = trackable;
}
}
else
{
throw new NotImplementedException("Encounter an unexpected behavior in <ListWrapper.SetAttr>, please " +
"submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
}

protected bool NonAppendMuation {
get => _non_append_mutation_value;
set
@@ -222,7 +535,7 @@ namespace Tensorflow.Training
}
}

public override IEnumerable<Trackable> Values => this;
public override ICollection<Trackable> Values => this;
public bool IsReadOnly { get => _storage.IsReadOnly; }

/// <summary>
@@ -239,7 +552,7 @@ namespace Tensorflow.Training

private void update_snapshot()
{
// TODO: deal with `attribute_sentinel`.
// TODO(Rinne): deal with `attribute_sentinel`.
if (_external_modification_value || _non_append_mutation_value) return;
_last_wrapped_list_snapshot = new List<Trackable>(_storage);
}
@@ -286,9 +599,9 @@ namespace Tensorflow.Training
{
base._track_value(value, name);
}
catch(ValueError ex)
catch(ValueError)
{
value = sticky_attribute_assignment(this, name, value);
value = (Trackable)sticky_attribute_assignment(this, name, value);
}
return value;
}
@@ -343,7 +656,11 @@ namespace Tensorflow.Training
update_snapshot();
}

public void Clear() => _storage.Clear();
public void Clear()
{
_storage.Clear();
update_snapshot();
}

public bool Contains(Trackable item) => _storage.Contains(item);



+ 8
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -519,6 +519,14 @@ namespace Tensorflow.Util
return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
}

public static T2 map_structure<T1, T2>(Func<T1, T2> func, T1 structure) where T2: class
{
var flat_structure = flatten(structure);
var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x);

return pack_sequence_as(structure, mapped_flat_structure) as T2;
}

/// <summary>
/// Same as map_structure, but with only one structure (no combining of multiple structures)
/// </summary>


+ 1
- 1
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -97,7 +97,7 @@ namespace Tensorflow
else
{
unique_id = $"{handle_name}_{ops.uid()}";
shared_name = tf.Context.shared_name();
shared_name = null;
}

var attr = new AttrValue();


+ 8
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -60,7 +60,15 @@ namespace Tensorflow.Keras

public void track_variable(IVariableV1 v)
{
if (tf.Context.executing_eagerly())
{
return;
}
var graph = v.Graph;
if(graph is null)
{
graph = get_graph();
}
_GRAPH_VARIABLES[graph.graph_key] = v;
}



+ 57
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -21,10 +21,13 @@ using System.Linq;
using System.Threading;
using Tensorflow.Eager;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
@@ -349,5 +352,59 @@ namespace Tensorflow.Keras.Engine
{
}

public override void SetAttr(string name, object value)
{
// TODO(Rinne): deal with "_self_setattr_tracking".

value = TrackableDataStructure.sticky_attribute_assignment(this, name, value);
foreach(var val in nest.flatten(value))
{
if(val is Metric)
{
// TODO(Rinne): deal with metrics.
}
}

// TODO(Rinne): deal with "_auto_track_sub_layers".

foreach(var val in nest.flatten(value))
{
if(val is not IVariableV1 variable)
{
continue;
}
if (variable.Trainable)
{
if (_trainable_weights.Contains(variable))
{
continue;
}
_trainable_weights.Add(variable);
}
else
{
if (_non_trainable_weights.Contains(variable))
{
continue;
}
_non_trainable_weights.Add(variable);
}
keras.backend.track_variable(variable);
}

// Directly use the implementation of `Trackable`.
var t = this.GetType();
var field_info = t.GetField(name);
if (field_info is not null)
{
field_info.SetValue(this, value);
}
else
{
CustomizedFields[name] = value;
}
}
}
}

+ 51
- 2
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -1,7 +1,12 @@
using Tensorflow.Keras.ArgsDefinition;
using System.Diagnostics;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine
{
@@ -22,14 +27,16 @@ namespace Tensorflow.Keras.Engine
IOptimizer optimizer;
IVariableV1 _steps_per_execution;
protected bool _is_graph_network;
protected Tensors inputs;
public Tensors inputs;
protected Tensors outputs;
protected List<string> input_names;
public string[] output_names;
IVariableV1 _train_counter;
IVariableV1 _test_counter;
IVariableV1 _predict_counter;
bool _base_model_initialized;
bool stop_training;
TensorSpec _saved_model_inputs_spec;

public bool IsGraphNetwork => _is_graph_network;
@@ -45,6 +52,38 @@ namespace Tensorflow.Keras.Engine
_init_batch_counters();
}

public void _set_inputs(TensorSpec inputs)
{
_set_save_spec(inputs);
}

internal void _set_save_spec(TensorSpec inputs)
{
if(_saved_model_inputs_spec is not null)
{
return;
}
var input_names = this.input_names;
if(input_names is null || input_names.Count == 0)
{
input_names = compile_utils.create_pseudo_input_names(inputs);
}

var flat_inputs = nest.flatten(inputs);
List<TensorSpec> specs = new();
foreach(var (name, tensor) in zip(input_names, flat_inputs))
{
specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name));
}
var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec;
Debug.Assert(specs is not null);
_saved_model_inputs_spec = packed_specs;
if(this is Sequential && _buildInputShape is null)
{
_buildInputShape = nest.map_structure<TensorSpec, TensorShapeConfig>(x => x is null ? null : x.shape, packed_specs);
}
}

internal override void Initialize(LayerArgs args)
{
_init_batch_counters();
@@ -145,6 +184,16 @@ namespace Tensorflow.Keras.Engine
return children;
}

public override void SetAttr(string name, object value)
{
// TODO(Rinne): deal with "_self_setattr_tracking".
//if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v)))
//{
// this._base_model_initialized;
//}
base.SetAttr(name, value);
}


void IModel.set_stopTraining_true()
{


+ 34
- 17
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -1,12 +1,14 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
@@ -17,6 +19,8 @@ using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.Binding;
@@ -190,12 +194,13 @@ namespace Tensorflow.Keras.Saving
Name = config["name"].ToObject<string>()
});
//s.Name = config["name"].ToObject<string>();
if(s.input is null || s.input.Length == 0)
if(s.inputs is null || s.inputs.Length == 0)
{
var first_layer = _get_child_layer_node_ids(model_id)[0];
var input_specs = _infer_inputs(first_layer);
var input_shapes = _infer_inputs(first_layer, true);
var input_shapes = _infer_input_shapes(first_layer);
// `model._set_inputs(input_specs)`
s._set_inputs(input_specs);

// skip the check of input_specs is Dictionary
if (!s.Built)
@@ -220,12 +225,12 @@ namespace Tensorflow.Keras.Saving

private void _set_network_attributes_from_metadata(Model revived_object)
{
var metadata = revived_object.SerializedAttributes["matadata"] as JObject;
if (metadata.ContainsKey("dtype"))
var metadata = revived_object.SerializedAttributes["metadata"] as KerasMetaData;
if (metadata.DType != TF_DataType.DtInvalid)
{
// TODO(Rinne): set_dtype_policy.
}
revived_object.args.Trainable = metadata["trainable"].Value<bool>();
revived_object.args.Trainable = metadata.Trainable;
}

/// <summary>
@@ -305,6 +310,11 @@ namespace Tensorflow.Keras.Saving
private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
// Debug(Rinne)
if(node_id == 11)
{
Console.WriteLine();
}

if (loaded_nodes.ContainsKey(node_id))
{
@@ -472,15 +482,7 @@ namespace Tensorflow.Keras.Saving
}
else
{
var properties = layer.GetType().GetProperties();
foreach(var p in properties)
{
if(p.Name == name as string && p.GetValue(layer) is not null)
{
return;
}
}
Loader.setattr(layer, name, value);
layer.SetAttr(name as string, value);
}
}

@@ -607,7 +609,7 @@ namespace Tensorflow.Keras.Saving

if(build_input_shape is null)
{
build_input_shape = _infer_inputs(node_id, convert_to_shapes: true);
build_input_shape = _infer_input_shapes(node_id);
}

if(build_input_shape is not null)
@@ -633,7 +635,7 @@ namespace Tensorflow.Keras.Saving
/// <param name="layer_node_id"></param>
/// <param name="convert_to_shapes"></param>
/// <returns></returns>
private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false)
private TensorSpec _infer_inputs(int layer_node_id)
{
var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" });
if(call_fn_id is null)
@@ -648,7 +650,22 @@ namespace Tensorflow.Keras.Saving
}
var call_fn_name = concrete_functions[0];
var call_fn_proto = _proto.ConcreteFunctions[call_fn_name];
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
var structured_input_signature = nested_structure_coder.decode_proto(call_fn_proto.CanonicalizedInputSignature);
Debug.Assert(structured_input_signature is IEnumerable);
var first_enumerator = (structured_input_signature as IEnumerable).GetEnumerator();
first_enumerator.MoveNext();
var first = first_enumerator.Current;
Debug.Assert(first is IEnumerable);
var inputs_enumerator = (first as IEnumerable).GetEnumerator();
inputs_enumerator.MoveNext();
var inputs = inputs_enumerator.Current as TensorSpec;
return inputs;
}

private Shape _infer_input_shapes(int layer_node_id)
{
var inputs = _infer_inputs(layer_node_id);
return nest.map_structure(x => x.shape, inputs);
}

private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child)


+ 1
- 13
src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs View File

@@ -48,19 +48,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
}
else
{
var properties = layer.GetType().GetProperties();
foreach (var p in properties)
{
if ((string)name == p.Name)
{
if(p.GetValue(layer) is not null)
{
return;
}
p.SetValue(layer, value);
return;
}
}
layer.SetAttr(name as string, value);
}
}
}


+ 10
- 5
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -11,7 +11,7 @@ using Tensorflow.Keras.Optimizers;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;
using Tensorflow.Training;
using System.Diagnostics;

namespace Tensorflow.Keras.Saving.SavedModel;

@@ -135,12 +135,17 @@ public partial class KerasSavedModelUtils
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));
var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));

Dictionary<string, Trackable> res = new();
res["variables"] = variables;
res["trainable_variables"] = trainable_variables;
res["non_trainable_variables"] = non_trainable_variables;
res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));
Debug.Assert(variables is Trackable);
Debug.Assert(trainable_variables is Trackable);
Debug.Assert(non_trainable_variables is Trackable);
Debug.Assert(layers is Trackable);
res["variables"] = variables as Trackable;
res["trainable_variables"] = trainable_variables as Trackable;
res["non_trainable_variables"] = non_trainable_variables as Trackable;
res["layers"] = layers as Trackable;

return res;
}


+ 8
- 0
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -165,6 +165,14 @@ namespace Tensorflow.Keras.Utils
}
}

public static bool has_weights(object obj)
{
var obj_type = obj.GetType();
return obj_type.GetField("trainable_weights") is not null &&
obj_type.GetField("non_trainable_weights") is not null &&
obj is not Type;
}

// recusive
static bool uses_keras_history(Tensor op_input)
{


+ 22
- 0
src/TensorFlowNET.Keras/Utils/compile_utils.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework.Models;
using Tensorflow.Util;

namespace Tensorflow.Keras.Utils
{
internal static class compile_utils
{
public static List<string> create_pseudo_input_names(TensorSpec inputs)
{
return _create_pseudo_names(inputs, "input_");
}

private static List<string> _create_pseudo_names(TensorSpec tensors, string prefix)
{
// TODO(Rinne): align with tensorflow
return new List<string>() { $"{prefix}1" };
}
}
}

+ 25
- 0
src/TensorFlowNET.Keras/Utils/tf_utils.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;

namespace Tensorflow.Keras.Utils
{
@@ -69,5 +70,29 @@ namespace Tensorflow.Keras.Utils
false_fn: false_fn,
name: name);
}

public static TensorSpec get_tensor_spec(Tensor t, bool dynamic_batch = false, string name = null)
{
throw new NotImplementedException("The function is waited to be implemented in the future.");
}

public static TensorSpec get_tensor_spec(TensorSpec t, bool dynamic_batch = false, string name = null)
{
var spec = t;
if (!dynamic_batch)
{
return spec;
}
var dynamic_batch_spec = new TensorSpec(t.shape, t.dtype, t.name);
var shape = dynamic_batch_spec.shape;
if(shape.rank > 0)
{
var shape_list = shape.as_int_list();
// TODO(Rinne): check if -1 is equivalent to None in python.
shape_list[0] = -1;
dynamic_batch_spec.shape = new Shape(shape_list);
}
return dynamic_batch_spec;
}
}
}

+ 3
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -64,5 +64,8 @@ public class SequentialModelLoad
{
var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func");
model.summary();

var x = tf.ones((2, 10));
var y = model.Apply(x);
}
}

Loading…
Cancel
Save