@@ -30,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||
); | |||
public static class SaveUtil | |||
{ | |||
public static (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
{ | |||
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
@@ -119,16 +119,16 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="call_with_mapped_captures"></param> | |||
/// <param name="cache"></param> | |||
/// <param name="object_graph_proto"></param> | |||
private static IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||
{ | |||
Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
foreach(var td in tensor_trackables) | |||
{ | |||
// TODO: deal with cache. | |||
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
Trackable trackable = null; | |||
IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict; | |||
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
{ | |||
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||
@@ -150,12 +150,12 @@ namespace Tensorflow.Checkpoint | |||
return serialized_tensors; | |||
} | |||
private static IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
{ | |||
var trackable = trackable_data.object_to_save; | |||
// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | |||
IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict; | |||
if (call_with_mapped_captures) | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -165,8 +165,7 @@ namespace Tensorflow.Checkpoint | |||
ret_tensor_dict = trackable.serialize_to_tensors(); | |||
} | |||
// TODO: deal with the type `SaveSpce` (currently it will never be it). | |||
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||
foreach(var pair in ret_tensor_dict) | |||
{ | |||
var local_name = TrackableUtils.escape_local_name(pair.Key); | |||
@@ -175,10 +174,12 @@ namespace Tensorflow.Checkpoint | |||
tensor_dict[checkpoint_key] = maybe_tensor; | |||
if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||
foreach(var key in maybe_tensor.Keys) | |||
{ | |||
throw new NotImplementedException(); | |||
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>()) | |||
{ | |||
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; | |||
} | |||
} | |||
if(object_graph_proto is not null) | |||
@@ -202,7 +203,7 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="call_with_mapped_captures"></param> | |||
/// <param name="object_graph_proto"></param> | |||
/// <returns></returns> | |||
private static (Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
{ | |||
Dictionary<Trackable, string> object_names = new(); | |||
@@ -45,12 +45,12 @@ public class TrackableSaver | |||
_graph_view = graph_view; | |||
// TODO: cache when not executing eagerly. | |||
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` | |||
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||
} | |||
private (IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
{ | |||
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
@@ -69,9 +69,10 @@ public class TrackableSaver | |||
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
if (!serialized_tensors.ContainsKey(Trackable.None)) | |||
{ | |||
serialized_tensors[Trackable.None] = new Dictionary<string, OneOf.OneOf<Tensor, IDictionary<string, Tensor>>>(); | |||
serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>(); | |||
} | |||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); | |||
return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||
} | |||
@@ -387,6 +388,7 @@ public class CheckpointRestoreCoordinator | |||
/// </summary> | |||
public List<Trackable> AllTrackables => _all_trackables; | |||
public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||
// TODO(Rinne): change to weak ref. | |||
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||
public int RestoreUid => _restore_uid; | |||
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | |||
@@ -160,12 +160,12 @@ namespace Tensorflow.Checkpoint | |||
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||
/// <param name="registered_savers"></param> | |||
/// <param name="call_with_mapped_capture"></param> | |||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors, | |||
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
{ | |||
_keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||
_restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||
Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new(); | |||
foreach(var pair in serialized_tensors) | |||
{ | |||
@@ -191,16 +191,7 @@ namespace Tensorflow.Checkpoint | |||
foreach(var item in tensor_dict) | |||
{ | |||
var checkpoint_key = item.Key; | |||
IDictionary<string, Tensor> spec_to_tensor; | |||
if(item.Value.TryPickT0(out var t, out var dic)) | |||
{ | |||
spec_to_tensor = new Dictionary<string, Tensor>(); | |||
spec_to_tensor[""] = t; | |||
} | |||
else | |||
{ | |||
spec_to_tensor = dic; | |||
} | |||
var spec_to_tensor = item.Value; | |||
foreach(var spec in spec_to_tensor) | |||
{ | |||
@@ -216,11 +207,19 @@ namespace Tensorflow.Checkpoint | |||
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||
// skip the process of device name because lack of API. | |||
var host_device = tensor.Device; | |||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||
string host_device; | |||
if (tensor.IsT0) | |||
{ | |||
host_device = tensor.AsT0.Device; | |||
} | |||
else | |||
{ | |||
host_device = tensor.AsT1.device; | |||
} | |||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||
internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
} | |||
internal_dict[checkpoint_key][slice_spec] = tensor; | |||
} | |||
@@ -425,7 +424,7 @@ namespace Tensorflow.Checkpoint | |||
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
{ | |||
Dictionary<Trackable, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
foreach (var saveable in saveables) | |||
{ | |||
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||
@@ -3,6 +3,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Security; | |||
using System.Text; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
@@ -50,7 +51,7 @@ public class CheckpointPosition | |||
{ | |||
_checkpoint.AllTrackables.Add(trackable); | |||
_checkpoint.MatchedProtoIds.Add(_proto_id); | |||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) | |||
{ | |||
// skip the `logging.warning`. | |||
return false; | |||
@@ -120,6 +120,11 @@ namespace Tensorflow.Contexts | |||
name : | |||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
public string anonymous_name() | |||
{ | |||
return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
} | |||
public void graph_mode(bool isFunc = false) | |||
=> context_switches.Push(false, isFunc); | |||
@@ -6,8 +6,11 @@ | |||
public class DenseSpec : TypeSpec | |||
{ | |||
protected Shape _shape; | |||
public Shape shape => _shape; | |||
public Shape shape | |||
{ | |||
get { return _shape; } | |||
set { _shape = value; } | |||
} | |||
protected TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
@@ -311,7 +311,7 @@ namespace Tensorflow | |||
/// <param name="types">const TF_DataType*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, | |||
public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, | |||
int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | |||
SafeStatusHandle status); | |||
@@ -30,6 +30,18 @@ namespace Tensorflow.Operations | |||
} | |||
} | |||
public static HandleData create_handle_data(Shape shape, TF_DataType dtype) | |||
{ | |||
HandleData handle_data = new(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
{ | |||
Shape = shape.as_proto(), | |||
Dtype = dtype.as_datatype_enum() | |||
}); | |||
return handle_data; | |||
} | |||
public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||
{ | |||
if(target_t is EagerTensor) | |||
@@ -37,7 +49,8 @@ namespace Tensorflow.Operations | |||
target_t.HandleData = handle_data; | |||
return; | |||
} | |||
c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
// TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) | |||
//c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
} | |||
} | |||
} |
@@ -21,6 +21,9 @@ using Tensorflow.Train; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using Tensorflow.Variables; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
using System.Buffers; | |||
namespace Tensorflow | |||
{ | |||
@@ -31,6 +34,7 @@ namespace Tensorflow | |||
{ | |||
public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | |||
{ | |||
// TODO(Rinne): deal with `_handle_graph`. | |||
var value_tensor = ops.convert_to_tensor(value); | |||
return gen_resource_variable_ops.assign_variable_op(handle, | |||
value_tensor, | |||
@@ -78,6 +82,18 @@ namespace Tensorflow | |||
string shared_name, string name, bool graph_mode, Tensor initial_value = null) | |||
{ | |||
var container = ops.get_default_graph().Container; | |||
if(container is null) | |||
{ | |||
container = ""; | |||
} | |||
if (!graph_mode) | |||
{ | |||
if(shared_name is not null) | |||
{ | |||
throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); | |||
} | |||
shared_name = tf.Context.anonymous_name(); | |||
} | |||
var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | |||
dtype: dtype, | |||
shared_name: shared_name, | |||
@@ -95,26 +111,20 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
// We do not want two distinct ResourceVariable objects for the same | |||
// underlying resource in the runtime. | |||
// When in eager mode, explicitly ensure so here. When in graph mode, it's | |||
// ensured by always generating different variable names. | |||
var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
// We create an assert Op instead of checking right away in order to be | |||
// compatible with ASYNC execution mode. Further, since not all devices | |||
// support string tensors, we encode the assertion string in the Op name | |||
/*gen_logging_ops.assert(gen_math_ops.logical_not(exists), | |||
new[] { exists }, | |||
name: "EagerVariableNameReuse");*/ | |||
var handle_data = new HandleData(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType | |||
var handle_data = handle_data_util.create_handle_data(shape, dtype); | |||
if (initial_value is not null && initial_value.dtype == dtypes.variant) | |||
{ | |||
Dtype = dtype.as_datatype_enum(), | |||
Shape = shape.as_proto() | |||
}); | |||
var extra_handle_data = get_eager_safe_handle_data(initial_value); | |||
if (extra_handle_data is not null && extra_handle_data.IsSet) | |||
{ | |||
if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
{ | |||
throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + | |||
$"but saw: '{handle_data}'"); | |||
} | |||
handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); | |||
} | |||
} | |||
_set_handle_shapes_and_types(handle, handle_data, graph_mode); | |||
return handle; | |||
} | |||
@@ -126,24 +136,48 @@ namespace Tensorflow | |||
/// <param name="handle"></param> | |||
/// <param name="handle_data"></param> | |||
/// <param name="graph_mode"></param> | |||
internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
{ | |||
tensor.HandleData = handle_data; | |||
if (!graph_mode) | |||
return; | |||
var size = handle_data.ShapeAndType.Count; | |||
//var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||
//var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||
//var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||
//var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s => | |||
//{ | |||
// if (!s.UnknownRank) | |||
// { | |||
// return s.Dim.Select(d => (int)d.Size).ToArray(); | |||
// } | |||
// else | |||
// { | |||
// return Memory<int>.Empty; | |||
// } | |||
//}).ToArray(); | |||
var shapes = new IntPtr[size]; | |||
var types = new DataType[size]; | |||
var ranks = new int[size]; | |||
//List<MemoryHandle> handles = new(); | |||
//IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; | |||
//foreach(var (i, m) in enumerate(converted_shapes)) | |||
//{ | |||
// if(m.IsEmpty) | |||
// { | |||
// shapes_with_ptr[i] = IntPtr.Zero; | |||
// } | |||
// else | |||
// { | |||
// var handle = m.Pin(); | |||
// handles.Add(handle); | |||
// shapes_with_ptr[i] = new IntPtr(handle.Pointer); | |||
// } | |||
//} | |||
for (int i = 0; i < size; i++) | |||
{ | |||
var shapeAndType = handle_data.ShapeAndType[i]; | |||
types[i] = shapeAndType.Dtype; | |||
ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | |||
var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | |||
} | |||
//Status status = new(); | |||
//// TODO(Rinne): enable it. | |||
//c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), | |||
// shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); | |||
//handles = null; | |||
} | |||
/// <summary> | |||
@@ -330,7 +330,7 @@ namespace Tensorflow { | |||
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | |||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_dependencies_codec | |||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
= pb::FieldCodec.ForMessage(122, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
/// <summary> | |||
@@ -698,9 +698,13 @@ namespace Tensorflow { | |||
break; | |||
case 10: { | |||
children_.AddEntriesFrom(input, _repeated_children_codec); | |||
dependencies_.AddRange(children_.Except(dependencies_)); | |||
break; | |||
} | |||
case 122: | |||
{ | |||
dependencies_.AddEntriesFrom(input, _repeated_dependencies_codec); | |||
break; | |||
} | |||
case 26: { | |||
slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); | |||
break; | |||
@@ -3,6 +3,7 @@ using System.Linq; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Operations.Activation; | |||
using Tensorflow.Training; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Train | |||
@@ -25,6 +26,13 @@ namespace Tensorflow.Train | |||
} | |||
} | |||
public override void SetAttr(string name, object value) | |||
{ | |||
// TODO(Rinne): deal with `self_setattr_tracking`. | |||
value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||
base.SetAttr(name, value); | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
if(save_type != SaveType.SAVEDMODEL) | |||
@@ -34,6 +42,7 @@ namespace Tensorflow.Train | |||
Dictionary<string, Trackable> functions = new(); | |||
// TODO: process of logs. | |||
// TODO(Rinne): deal with members. | |||
var properties = this.GetType().GetProperties(); | |||
foreach ( var property in properties ) | |||
{ | |||
@@ -45,6 +54,16 @@ namespace Tensorflow.Train | |||
} | |||
} | |||
foreach(var item in CustomizedFields) | |||
{ | |||
var name = item.Key; | |||
var value = item.Value; | |||
if (value is Function or ConcreteFunction) | |||
{ | |||
functions[name] = (Trackable)value; | |||
} | |||
} | |||
// TODO: process the type `core_types.GenericFunction`. | |||
Dictionary<string, Trackable> children = new(); | |||
@@ -42,22 +42,25 @@ namespace Tensorflow | |||
_var_device = var.Device; | |||
_var_shape = var.shape; | |||
Tensor _read_variable_closure(BaseResourceVariable v) | |||
Func<Tensor> _read_variable_closure(BaseResourceVariable v) | |||
{ | |||
tf.device(v.Device); | |||
if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
return () => | |||
{ | |||
return null; | |||
} | |||
var x = v.read_value_no_copy(); | |||
tf.device("/device:CPU:0"); | |||
return array_ops.identity(x); | |||
tf.device(v.Device); | |||
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
{ | |||
return null; | |||
} | |||
var x = v.read_value_no_copy(); | |||
tf.device("/device:CPU:0"); | |||
return array_ops.identity(x); | |||
}; | |||
} | |||
this.handle_op = var.Handle; | |||
var tensor = _read_variable_closure(var); | |||
var tensor_creator = _read_variable_closure(var); | |||
var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); | |||
var spec = new SaveSpec(tensor_creator, slice_spec, name, dtype: var.dtype, device: var.Device); | |||
_op = var; | |||
specs = new SaveSpec[] { spec }; | |||
this.name = name; | |||
@@ -66,6 +69,7 @@ namespace Tensorflow | |||
public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||
{ | |||
var restored_tensor = restored_tensors[0]; | |||
tf.device(_var_device); | |||
restored_tensor = array_ops.identity(restored_tensor); | |||
return resource_variable_ops.shape_safe_assign_variable_handle( | |||
handle_op, _var_shape, restored_tensor); | |||
@@ -14,6 +14,8 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Exceptions; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
@@ -21,8 +23,24 @@ namespace Tensorflow | |||
/// </summary> | |||
public class SaveSpec | |||
{ | |||
private Tensor _tensor; | |||
public Tensor tensor => _tensor; | |||
private Tensor _tensor = null; | |||
private Func<Tensor> _tensor_creator = null; | |||
public Tensor tensor | |||
{ | |||
get | |||
{ | |||
if(_tensor is not null || _tensor_creator is null) | |||
{ | |||
return _tensor; | |||
} | |||
else | |||
{ | |||
return _tensor_creator(); | |||
} | |||
} | |||
} | |||
internal Func<Tensor> TensorCreator => _tensor_creator; | |||
private string _slice_spec; | |||
public string slice_spec => _slice_spec; | |||
@@ -32,13 +50,36 @@ namespace Tensorflow | |||
private TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
private string _device; | |||
public string device => _device; | |||
public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) | |||
{ | |||
_tensor = tensor; | |||
_slice_spec = slice_spec; | |||
_name = name; | |||
_dtype = dtype; | |||
if(device is not null) | |||
{ | |||
_device = device; | |||
} | |||
else | |||
{ | |||
_device = tensor.Device; | |||
} | |||
} | |||
public SaveSpec(Func<Tensor> tensor_creator, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) | |||
{ | |||
_tensor_creator = tensor_creator; | |||
_slice_spec = slice_spec; | |||
_name = name; | |||
if(dtype == TF_DataType.DtInvalid || device is null) | |||
{ | |||
throw new AssertionError("When passing a callable `tensor` to a SaveSpec, an explicit dtype and device must be provided."); | |||
} | |||
_dtype = dtype; | |||
_device = device; | |||
} | |||
} | |||
} |
@@ -1,10 +1,20 @@ | |||
using System; | |||
using System.Diagnostics; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
namespace Tensorflow; | |||
public class RevivedTypes | |||
{ | |||
private static Dictionary<string, ITrackableWrapper> _registered_revived_creator = new(); | |||
static RevivedTypes() | |||
{ | |||
var list_wrapper = new ListWrapper(new Trackable[] { }); | |||
_registered_revived_creator[list_wrapper.Identifier] = list_wrapper; | |||
var dict_wrapper = new DictWrapper(new Dictionary<object, Trackable>()); | |||
_registered_revived_creator[dict_wrapper.Identifier] = dict_wrapper; | |||
} | |||
/// <summary> | |||
/// Create a SavedUserObject from a trackable object. | |||
/// </summary> | |||
@@ -12,13 +22,28 @@ public class RevivedTypes | |||
/// <returns></returns> | |||
public static SavedUserObject? serialize(Trackable obj) | |||
{ | |||
// TODO: complete the implementation. | |||
// TODO(Rinne): complete the implementation. | |||
return null; | |||
} | |||
public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto) | |||
public static (Trackable, Action<object, object, object>) deserialize(SavedUserObject proto) | |||
{ | |||
// TODO: complete the implementation. | |||
return null; | |||
if(_registered_revived_creator.TryGetValue(proto.Identifier, out var wrapper)) | |||
{ | |||
return (wrapper.FromProto(proto), (x, y, z) => | |||
{ | |||
if (x is not ITrackableWrapper trackable) | |||
{ | |||
throw new TypeError($"The type is expected to be `ITrackableWrapper`, but got {x.GetType()}."); | |||
} | |||
Debug.Assert(y is string); | |||
trackable.SetValue(y, z); | |||
} | |||
); | |||
} | |||
else | |||
{ | |||
return (null, null); | |||
} | |||
} | |||
} |
@@ -49,6 +49,7 @@ namespace Tensorflow | |||
var temp = _proto.ToString(); | |||
_export_dir = export_dir; | |||
// TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||
// TODO(Rinne): This method is very slow, needs to be accelareted. | |||
_concrete_functions = function_deserialization.load_function_def_library( | |||
meta_graph.GraphDef.Library, _proto); | |||
_restored_concrete_functions = new HashSet<string>(); | |||
@@ -523,7 +524,7 @@ namespace Tensorflow | |||
continue; | |||
} | |||
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | |||
// skip the process of "__call__" | |||
// TODO(Rinne): deal with "__call__" | |||
} | |||
} | |||
@@ -595,13 +596,12 @@ namespace Tensorflow | |||
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id) | |||
{ | |||
// skip the check of proto identifier because of lack of property. | |||
var looked_up = RevivedTypes.deserialize(proto); | |||
if(looked_up is null) | |||
var (trackable, setter) = RevivedTypes.deserialize(proto); | |||
if(trackable is null) | |||
{ | |||
return _recreate_base_user_object(proto, node_id); | |||
} | |||
return (looked_up.Item1, looked_up.Item2); | |||
return (trackable, setter); | |||
} | |||
private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) | |||
@@ -668,13 +668,20 @@ namespace Tensorflow | |||
public static Action<object, object, object> setattr = (x, y, z) => | |||
{ | |||
Debug.Assert(y is string); | |||
var properties = x.GetType().GetProperties(); | |||
foreach(var p in properties) | |||
if(x is Trackable trackable) | |||
{ | |||
trackable.SetAttr(y as string, z); | |||
} | |||
else | |||
{ | |||
if((string)y == p.Name) | |||
var properties = x.GetType().GetProperties(); | |||
foreach (var p in properties) | |||
{ | |||
p.SetValue(x, z); | |||
return; | |||
if ((string)y == p.Name) | |||
{ | |||
p.SetValue(x, z); | |||
return; | |||
} | |||
} | |||
} | |||
// TODO(Rinne): check if the property has been set successfully. | |||
@@ -50,6 +50,10 @@ namespace Tensorflow | |||
} | |||
public static class saveable_object_util | |||
{ | |||
public static string NO_SLICE_SPEC_KEY = ""; | |||
private static HashSet<string> _VARIABLE_OPS = new HashSet<string>(new string[] { | |||
"Variable", "VariableV2", "AutoReloadVariable", "VarHandleOp", "ReadVariableOp" | |||
}); | |||
/// <summary> | |||
/// Returns the variables and names that will be used for a Saver. | |||
/// </summary> | |||
@@ -123,19 +127,12 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static IEnumerable<MySaveableObject> saveable_objects_for_op(Tensor op, string name) | |||
{ | |||
if (false) | |||
{ | |||
} | |||
ops.init_scope(); | |||
var variable = ops.convert_to_tensor(op, as_ref: true); | |||
if (variable.dtype.is_ref_dtype()) | |||
yield return new ReferenceVariableSaveable(variable, "", name); | |||
else | |||
{ | |||
ops.init_scope(); | |||
var variable = ops.convert_to_tensor(op, as_ref: true); | |||
if (variable.dtype.is_ref_dtype()) | |||
yield return new ReferenceVariableSaveable(variable, "", name); | |||
else | |||
yield return new ResourceVariableSaveable(variable, "", name); | |||
} | |||
yield return new ResourceVariableSaveable(variable, "", name); | |||
} | |||
/// <summary> | |||
@@ -159,7 +156,7 @@ namespace Tensorflow | |||
yield return new ResourceVariableSaveable(variable, "", name); | |||
} | |||
} | |||
else | |||
else if(obj is not IVariableV1) | |||
{ | |||
foreach(var pair in saveable_objects_from_trackable(obj)) | |||
{ | |||
@@ -191,6 +188,30 @@ namespace Tensorflow | |||
} | |||
} | |||
} | |||
else | |||
{ | |||
// Variable | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
throw new ValueError($"Can only save/restore ResourceVariables when " + | |||
$"executing eagerly, got type: {obj.GetType()}."); | |||
} | |||
var variable = ops.convert_to_tensor(obj, as_ref: true); | |||
if (!_tensor_comes_from_variable(variable)) | |||
{ | |||
throw new TypeError($"names_to_saveables must be a dict mapping string " + | |||
$"names to Tensors/Variables. Not a variable: {variable}"); | |||
} | |||
if(variable.op.type == "Variable" || variable.op.type == "VariableV2" || | |||
variable.op.type == "AutoReloadVariable") | |||
{ | |||
yield return new ReferenceVariableSaveable(variable, "", name); | |||
} | |||
else | |||
{ | |||
yield return new ResourceVariableSaveable(variable, "", name); | |||
} | |||
} | |||
} | |||
/// <summary> | |||
@@ -267,24 +288,14 @@ namespace Tensorflow | |||
foreach (var pair in tensor_dict) | |||
{ | |||
var tensor_name = pair.Key; | |||
var maybe_tensor = pair.Value; | |||
var internal_dict = pair.Value; | |||
local_names.Add(tensor_name); | |||
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
IDictionary<string, Tensor> internal_dict; | |||
if (maybe_tensor.TryPickT0(out var tensor, out var dic)) | |||
{ | |||
internal_dict = new Dictionary<string, Tensor>(); | |||
internal_dict[""] = tensor; | |||
} | |||
else | |||
{ | |||
internal_dict = dic; | |||
} | |||
foreach (var item in internal_dict) | |||
{ | |||
specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
Debug.Assert(item.Value.IsT0); | |||
specs.Add(new SaveSpec(item.Value.AsT0, item.Key, spec_name)); | |||
} | |||
} | |||
return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
@@ -316,9 +327,9 @@ namespace Tensorflow | |||
/// Converts a list of SaveableObjects to a tensor dictionary. | |||
/// </summary> | |||
/// <param name="saveables"></param> | |||
public static Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
public static Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables) | |||
{ | |||
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||
foreach (var saveable in saveables) | |||
{ | |||
foreach (var spec in saveable.specs) | |||
@@ -326,14 +337,11 @@ namespace Tensorflow | |||
// skip the check that if `spec` is callable. | |||
var name = convert_to_string(spec.name); | |||
var slice_spec = convert_to_string(spec.slice_spec); | |||
if (!string.IsNullOrEmpty(slice_spec)) | |||
{ | |||
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).AsT1[slice_spec] = spec.tensor; | |||
} | |||
else | |||
if (string.IsNullOrEmpty(slice_spec)) | |||
{ | |||
tensor_dict[name] = spec.tensor; | |||
slice_spec = NO_SLICE_SPEC_KEY; | |||
} | |||
tensor_dict.SetDefault(name, new Dictionary<string, OneOf<Tensor, SaveSpec>>())[slice_spec] = spec.TensorCreator is null ? spec.tensor : spec; | |||
} | |||
} | |||
return tensor_dict; | |||
@@ -397,6 +405,11 @@ namespace Tensorflow | |||
{ | |||
return factory(key); | |||
} | |||
private static bool _tensor_comes_from_variable(object v) | |||
{ | |||
return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); | |||
} | |||
} | |||
public class SaveableCompatibilityConverter: Trackable | |||
@@ -412,7 +425,7 @@ namespace Tensorflow | |||
public object Obj => _obj; | |||
public IList<MySaveableObject> mySaveables=> _saveables; | |||
public override IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
public override IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors() | |||
{ | |||
return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | |||
} | |||
@@ -85,6 +85,72 @@ namespace Tensorflow.Train | |||
_self_saveable_object_factories = value; | |||
} | |||
} | |||
public Dictionary<string, object> CustomizedFields { get; set; } = new Dictionary<string, object>(); | |||
public virtual void SetAttr(string name, object value) | |||
{ | |||
var t = this.GetType(); | |||
var field_info = t.GetField(name); | |||
if(field_info is not null) | |||
{ | |||
field_info.SetValue(this, value); | |||
} | |||
else | |||
{ | |||
CustomizedFields[name] = value; | |||
} | |||
// On account of performance, we don't use reflection to set the attribute if it exists in `Trackable`. | |||
// When adding new members or properties to this class, please add corresponding process to this method. | |||
//switch (name) | |||
//{ | |||
// case "_manual_tracking": | |||
// { | |||
// _manual_tracking = (bool)value; | |||
// break; | |||
// } | |||
// case "_self_saveable_object_factories": | |||
// { | |||
// _self_saveable_object_factories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value; | |||
// break; | |||
// } | |||
// case "_self_update_uid": | |||
// { | |||
// _self_update_uid = (int)value; | |||
// break; | |||
// } | |||
// case "_unconditional_checkpoint_dependencies": | |||
// { | |||
// _unconditional_checkpoint_dependencies = (IList<TrackableReference>)value; | |||
// break; | |||
// } | |||
// case "_unconditional_deferred_dependencies": | |||
// { | |||
// _unconditional_deferred_dependencies = (Dictionary<string, IList<CheckpointPosition>>)value; | |||
// break; | |||
// } | |||
// case "_unconditional_dependency_names": | |||
// { | |||
// _unconditional_dependency_names = (IDictionary<string, Trackable>)value; | |||
// break; | |||
// } | |||
// case "SelfSaveableObjectFactories": | |||
// { | |||
// SelfSaveableObjectFactories = (IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>)value; | |||
// break; | |||
// } | |||
// case "UpdateUid": | |||
// { | |||
// UpdateUid = (int)value; | |||
// break; | |||
// } | |||
// default: | |||
// { | |||
// CustomizedAttributes[name] = value; | |||
// break; | |||
// } | |||
// } | |||
} | |||
/// <summary> | |||
/// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
@@ -279,7 +345,7 @@ namespace Tensorflow.Train | |||
/// </summary> | |||
/// <returns></returns> | |||
/// <exception cref="NotImplementedException"></exception> | |||
public virtual IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
public virtual IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> serialize_to_tensors() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -2,6 +2,8 @@ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.IO.Compression; | |||
using System.Linq; | |||
using System.Linq.Expressions; | |||
@@ -25,6 +27,48 @@ namespace Tensorflow.Training | |||
} | |||
} | |||
static class TrackableWrapperUtils | |||
{ | |||
internal static bool ShouldLoad(ITrackableWrapper wrapper, SavedUserObject proto) | |||
{ | |||
if (proto.Identifier != wrapper.Identifier) | |||
{ | |||
return false; | |||
} | |||
if (wrapper.Version < proto.Version.MinConsumer) | |||
{ | |||
return false; | |||
} | |||
if (proto.Version.Producer < wrapper.MinProducerVersion) | |||
{ | |||
return false; | |||
} | |||
foreach (var bad_version in proto.Version.BadConsumers) | |||
{ | |||
if (bad_version == wrapper.Version) | |||
{ | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
internal static bool is_function(Trackable x) | |||
{ | |||
return x is Function or ConcreteFunction; | |||
} | |||
} | |||
public interface ITrackableWrapper | |||
{ | |||
void SetValue(object name, object value); | |||
String Identifier { get; } | |||
int Version { get; } | |||
int MinConsumerVersion { get; } | |||
int MinProducerVersion { get; } | |||
Trackable FromProto(SavedUserObject proto); | |||
} | |||
public abstract class TrackableDataStructure : Trackable | |||
{ | |||
private bool _self_trainable; | |||
@@ -36,7 +80,7 @@ namespace Tensorflow.Training | |||
_self_extra_variables = new List<IVariableV1>(); | |||
} | |||
public abstract IEnumerable<Trackable> Values { get; } | |||
public abstract ICollection<Trackable> Values { get; } | |||
public bool Trainable { get => _self_trainable; set => _self_trainable = value; } | |||
public IEnumerable<ILayer> Layers | |||
{ | |||
@@ -134,7 +178,7 @@ namespace Tensorflow.Training | |||
/// <param name="name"></param> | |||
protected virtual Trackable _track_value(Trackable value, string name) | |||
{ | |||
value = sticky_attribute_assignment(this, name, value); | |||
value = (Trackable)sticky_attribute_assignment(this, name, value); | |||
if(value is IVariableV1) | |||
{ | |||
_self_extra_variables.Add(value as IVariableV1); | |||
@@ -148,44 +192,273 @@ namespace Tensorflow.Training | |||
return value.Value; | |||
} | |||
public static Trackable wrap_or_unwrap(Trackable value) | |||
public static object wrap_or_unwrap(object value) | |||
{ | |||
if(value is NoDependency dependency) | |||
{ | |||
return dependency.Value; | |||
} | |||
if(value is Trackable trackable) | |||
{ | |||
return trackable; | |||
} | |||
else if(value is IDictionary<object, Trackable> obj_dict) | |||
{ | |||
return new DictWrapper(obj_dict); | |||
} | |||
else if(value is IList<Trackable> list) | |||
{ | |||
return new ListWrapper(list); | |||
} | |||
else | |||
{ | |||
return value; | |||
} | |||
} | |||
public static object sticky_attribute_assignment(Trackable trackable, string name, object value) | |||
{ | |||
bool add_dependency = value is not NoDependency; | |||
value = wrap_or_unwrap(value); | |||
if (!add_dependency) | |||
{ | |||
return value; | |||
} | |||
if(value is Trackable trackable_obj) | |||
{ | |||
trackable._track_trackable(trackable_obj, name, true); | |||
} | |||
return value; | |||
} | |||
} | |||
// TODO(Rinne): Add Dict wrapper and Tuple wrapper | |||
public class DictWrapper : TrackableDataStructure, IDictionary<object, Trackable>, ICloneable, ITrackableWrapper | |||
{ | |||
private IDictionary<object, Trackable> _storage; | |||
private bool _non_string_key; | |||
private bool _external_modification; | |||
private IDictionary<object, Trackable> _last_wrapped_dict_snapshot; | |||
public DictWrapper(IDictionary<object, Trackable> wrapped_dict = null) | |||
{ | |||
if(wrapped_dict is not null) | |||
{ | |||
_storage = new Dictionary<object, Trackable>(wrapped_dict); | |||
} | |||
else | |||
{ | |||
_storage = new Dictionary<object, Trackable>(); | |||
} | |||
_update_snapshot(); | |||
} | |||
public static Trackable wrap_or_unwrap(IList<Trackable> value) | |||
public void SetValue(object name, object value) | |||
{ | |||
return new ListWrapper(value); | |||
Debug.Assert(value is Trackable); | |||
this[name] = value as Trackable; | |||
} | |||
public String Identifier => "trackable_dict_wrapper"; | |||
public int Version => 1; | |||
public int MinConsumerVersion => 1; | |||
public int MinProducerVersion => 1; | |||
public Trackable FromProto(SavedUserObject proto) | |||
{ | |||
return new DictWrapper(new Dictionary<object, Trackable>()); | |||
} | |||
public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value) | |||
public Trackable this[object key] | |||
{ | |||
return new ListWrapper(value.ToList()); | |||
get | |||
{ | |||
return _storage[key]; | |||
} | |||
set | |||
{ | |||
_check_self_external_modification(); | |||
_maybe_initialize_trackable(); | |||
bool no_dep = value is NoDependency; | |||
if(key is string) | |||
{ | |||
value = _track_value(value, key); | |||
} | |||
else | |||
{ | |||
value = (Trackable)wrap_or_unwrap(value); | |||
if(!no_dep && value is Trackable) | |||
{ | |||
_non_string_key = true; | |||
} | |||
} | |||
_storage[key] = value; | |||
_update_snapshot(); | |||
} | |||
} | |||
protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | |||
public ICollection<object> Keys => _storage.Keys; | |||
public override ICollection<Trackable> Values => _storage.OrderBy(x => x.Key).Select(x => x.Value).ToArray(); | |||
public void Add(object key, Trackable value) | |||
{ | |||
value = wrap_or_unwrap(value); | |||
trackable._track_trackable(value, name, true); | |||
return value; | |||
_storage[key] = value; | |||
} | |||
protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) | |||
public bool ContainsKey(object key) | |||
{ | |||
var wrapped_value = wrap_or_unwrap(value); | |||
trackable._track_trackable(wrapped_value, name, true); | |||
return wrapped_value; | |||
return _storage.ContainsKey(key); | |||
} | |||
protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value) | |||
public bool Remove(object key) | |||
{ | |||
var wrapped_value = wrap_or_unwrap(value); | |||
trackable._track_trackable(wrapped_value, name, true); | |||
return wrapped_value; | |||
_check_self_external_modification(); | |||
var res = _storage.Remove(key); | |||
_update_snapshot(); | |||
return res; | |||
} | |||
} | |||
public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable | |||
public bool TryGetValue(object key, out Trackable value) | |||
{ | |||
return _storage.TryGetValue(key, out value); | |||
} | |||
public int Count => _storage.Count; | |||
public bool IsReadOnly => _storage.IsReadOnly; | |||
public void Add(KeyValuePair<object, Trackable> item) | |||
{ | |||
Add(item.Key, item.Value); | |||
} | |||
public void Clear() | |||
{ | |||
_storage.Clear(); | |||
_update_snapshot(); | |||
} | |||
public bool Contains(KeyValuePair<object, Trackable> item) | |||
{ | |||
return _storage.Contains(item); | |||
} | |||
public void CopyTo(KeyValuePair<object, Trackable>[] array, int arrayIndex) | |||
{ | |||
_storage.CopyTo(array, arrayIndex); | |||
} | |||
public bool Remove(KeyValuePair<object, Trackable> item) | |||
{ | |||
_check_self_external_modification(); | |||
var res = Remove(item); | |||
_update_snapshot(); | |||
return res; | |||
} | |||
public IEnumerator<KeyValuePair<object, Trackable>> GetEnumerator() | |||
{ | |||
return _storage.GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); | |||
public object Clone() | |||
{ | |||
var copied = new DictWrapper(_storage); | |||
copied._external_modification = _external_modification; | |||
copied._non_string_key = _non_string_key; | |||
return copied; | |||
} | |||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
{ | |||
_check_self_external_modification(); | |||
if (_non_string_key) | |||
{ | |||
throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed \"" + | |||
$"automatically on attribute assignment). The wrapped dictionary " + | |||
$"contains a non-string key which maps to a trackable object or " + | |||
$"mutable data structure.\n\nIf you don't need this dictionary " + | |||
$"checkpointed, wrap it in a non-trackable " + | |||
$"object; it will be subsequently ignored."); | |||
} | |||
if (_external_modification) | |||
{ | |||
throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed " + | |||
$"automatically on attribute assignment). The wrapped dictionary was " + | |||
$"modified outside the wrapper (its final value was {this}, its value" + | |||
$" when a checkpoint dependency was added was " + | |||
$"{this._last_wrapped_dict_snapshot}), which breaks " + | |||
$"restoration on object creation.\n\nIf you don't need this " + | |||
$"dictionary checkpointed, wrap it in a " + | |||
$"non-trackable object; it will be subsequently ignored."); | |||
} | |||
Debug.Assert(!Dirty); | |||
var children = base._trackable_children(save_type, cache); | |||
if(save_type == SaveType.SAVEDMODEL) | |||
{ | |||
foreach(var item in _storage) | |||
{ | |||
var key = item.Key; | |||
var value = item.Value; | |||
if (TrackableWrapperUtils.is_function(value)) | |||
{ | |||
Debug.Assert(key is string); | |||
children[key as string] = value; | |||
} | |||
} | |||
} | |||
return children; | |||
} | |||
protected Trackable _track_value(Trackable value, object name) | |||
{ | |||
bool string_key = name is string; | |||
if (!string_key) | |||
{ | |||
name = "-non_string_key"; | |||
} | |||
try | |||
{ | |||
bool no_dependency = value is NoDependency; | |||
value = base._track_value(value, name as string); | |||
if(!(string_key || no_dependency)) | |||
{ | |||
_non_string_key = true; | |||
} | |||
return value; | |||
} | |||
catch (ValueError) | |||
{ | |||
return (Trackable)sticky_attribute_assignment(this, name as string, value); | |||
} | |||
} | |||
private bool Dirty => _external_modification || _non_string_key; | |||
private void _check_self_external_modification() | |||
{ | |||
if (Dirty) | |||
{ | |||
return; | |||
} | |||
if(!this._storage.SequenceEqual(_last_wrapped_dict_snapshot)) | |||
{ | |||
_external_modification = true; | |||
_last_wrapped_dict_snapshot = null; | |||
} | |||
} | |||
private void _update_snapshot() | |||
{ | |||
// TODO(Rinne): deal with attribute_sentinel. | |||
if (Dirty) return; | |||
_last_wrapped_dict_snapshot = new Dictionary<object, Trackable>(_storage); | |||
} | |||
} | |||
public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable, ITrackableWrapper | |||
{ | |||
private IList<Trackable> _storage; | |||
private bool _non_append_mutation_value; | |||
@@ -198,11 +471,51 @@ namespace Tensorflow.Training | |||
/// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param> | |||
public ListWrapper(IList<Trackable> wrapped_list) | |||
{ | |||
_storage = wrapped_list; | |||
_storage = new List<Trackable>(wrapped_list); | |||
_non_append_mutation_value = _external_modification_value = false; | |||
_last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||
} | |||
public string Identifier => "trackable_list_wrapper"; | |||
public int Version => 1; | |||
public int MinConsumerVersion => 1; | |||
public int MinProducerVersion => 1; | |||
public Trackable FromProto(SavedUserObject proto) | |||
{ | |||
if(TrackableWrapperUtils.ShouldLoad(this, proto)) | |||
{ | |||
return new ListWrapper(new Trackable[] { }); | |||
} | |||
else | |||
{ | |||
return null; | |||
} | |||
} | |||
public void SetValue(object name, object value) | |||
{ | |||
Debug.Assert(name is string); | |||
if(int.TryParse(name as string, out var index)) | |||
{ | |||
if(value is not Trackable trackable) | |||
{ | |||
throw new TypeError("Cannot set an object which is not trackable to ListWrapper."); | |||
} | |||
if(Count <= index) | |||
{ | |||
Add(trackable); | |||
} | |||
else | |||
{ | |||
this[index] = trackable; | |||
} | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("Encounter an unexpected behavior in <ListWrapper.SetAttr>, please " + | |||
"submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
protected bool NonAppendMuation { | |||
get => _non_append_mutation_value; | |||
set | |||
@@ -222,7 +535,7 @@ namespace Tensorflow.Training | |||
} | |||
} | |||
public override IEnumerable<Trackable> Values => this; | |||
public override ICollection<Trackable> Values => this; | |||
public bool IsReadOnly { get => _storage.IsReadOnly; } | |||
/// <summary> | |||
@@ -239,7 +552,7 @@ namespace Tensorflow.Training | |||
private void update_snapshot() | |||
{ | |||
// TODO: deal with `attribute_sentinel`. | |||
// TODO(Rinne): deal with `attribute_sentinel`. | |||
if (_external_modification_value || _non_append_mutation_value) return; | |||
_last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||
} | |||
@@ -286,9 +599,9 @@ namespace Tensorflow.Training | |||
{ | |||
base._track_value(value, name); | |||
} | |||
catch(ValueError ex) | |||
catch(ValueError) | |||
{ | |||
value = sticky_attribute_assignment(this, name, value); | |||
value = (Trackable)sticky_attribute_assignment(this, name, value); | |||
} | |||
return value; | |||
} | |||
@@ -343,7 +656,11 @@ namespace Tensorflow.Training | |||
update_snapshot(); | |||
} | |||
public void Clear() => _storage.Clear(); | |||
public void Clear() | |||
{ | |||
_storage.Clear(); | |||
update_snapshot(); | |||
} | |||
public bool Contains(Trackable item) => _storage.Contains(item); | |||
@@ -519,6 +519,14 @@ namespace Tensorflow.Util | |||
return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
} | |||
public static T2 map_structure<T1, T2>(Func<T1, T2> func, T1 structure) where T2: class | |||
{ | |||
var flat_structure = flatten(structure); | |||
var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); | |||
return pack_sequence_as(structure, mapped_flat_structure) as T2; | |||
} | |||
/// <summary> | |||
/// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
/// </summary> | |||
@@ -97,7 +97,7 @@ namespace Tensorflow | |||
else | |||
{ | |||
unique_id = $"{handle_name}_{ops.uid()}"; | |||
shared_name = tf.Context.shared_name(); | |||
shared_name = null; | |||
} | |||
var attr = new AttrValue(); | |||
@@ -60,7 +60,15 @@ namespace Tensorflow.Keras | |||
public void track_variable(IVariableV1 v) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return; | |||
} | |||
var graph = v.Graph; | |||
if(graph is null) | |||
{ | |||
graph = get_graph(); | |||
} | |||
_GRAPH_VARIABLES[graph.graph_key] = v; | |||
} | |||
@@ -21,10 +21,13 @@ using System.Linq; | |||
using System.Threading; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
@@ -349,5 +352,59 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
} | |||
public override void SetAttr(string name, object value) | |||
{ | |||
// TODO(Rinne): deal with "_self_setattr_tracking". | |||
value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); | |||
foreach(var val in nest.flatten(value)) | |||
{ | |||
if(val is Metric) | |||
{ | |||
// TODO(Rinne): deal with metrics. | |||
} | |||
} | |||
// TODO(Rinne): deal with "_auto_track_sub_layers". | |||
foreach(var val in nest.flatten(value)) | |||
{ | |||
if(val is not IVariableV1 variable) | |||
{ | |||
continue; | |||
} | |||
if (variable.Trainable) | |||
{ | |||
if (_trainable_weights.Contains(variable)) | |||
{ | |||
continue; | |||
} | |||
_trainable_weights.Add(variable); | |||
} | |||
else | |||
{ | |||
if (_non_trainable_weights.Contains(variable)) | |||
{ | |||
continue; | |||
} | |||
_non_trainable_weights.Add(variable); | |||
} | |||
keras.backend.track_variable(variable); | |||
} | |||
// Directly use the implementation of `Trackable`. | |||
var t = this.GetType(); | |||
var field_info = t.GetField(name); | |||
if (field_info is not null) | |||
{ | |||
field_info.SetValue(this, value); | |||
} | |||
else | |||
{ | |||
CustomizedFields[name] = value; | |||
} | |||
} | |||
} | |||
} |
@@ -1,7 +1,12 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using System.Diagnostics; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Train; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -22,14 +27,16 @@ namespace Tensorflow.Keras.Engine | |||
IOptimizer optimizer; | |||
IVariableV1 _steps_per_execution; | |||
protected bool _is_graph_network; | |||
protected Tensors inputs; | |||
public Tensors inputs; | |||
protected Tensors outputs; | |||
protected List<string> input_names; | |||
public string[] output_names; | |||
IVariableV1 _train_counter; | |||
IVariableV1 _test_counter; | |||
IVariableV1 _predict_counter; | |||
bool _base_model_initialized; | |||
bool stop_training; | |||
TensorSpec _saved_model_inputs_spec; | |||
public bool IsGraphNetwork => _is_graph_network; | |||
@@ -45,6 +52,38 @@ namespace Tensorflow.Keras.Engine | |||
_init_batch_counters(); | |||
} | |||
public void _set_inputs(TensorSpec inputs) | |||
{ | |||
_set_save_spec(inputs); | |||
} | |||
internal void _set_save_spec(TensorSpec inputs) | |||
{ | |||
if(_saved_model_inputs_spec is not null) | |||
{ | |||
return; | |||
} | |||
var input_names = this.input_names; | |||
if(input_names is null || input_names.Count == 0) | |||
{ | |||
input_names = compile_utils.create_pseudo_input_names(inputs); | |||
} | |||
var flat_inputs = nest.flatten(inputs); | |||
List<TensorSpec> specs = new(); | |||
foreach(var (name, tensor) in zip(input_names, flat_inputs)) | |||
{ | |||
specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name)); | |||
} | |||
var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec; | |||
Debug.Assert(specs is not null); | |||
_saved_model_inputs_spec = packed_specs; | |||
if(this is Sequential && _buildInputShape is null) | |||
{ | |||
_buildInputShape = nest.map_structure<TensorSpec, TensorShapeConfig>(x => x is null ? null : x.shape, packed_specs); | |||
} | |||
} | |||
internal override void Initialize(LayerArgs args) | |||
{ | |||
_init_batch_counters(); | |||
@@ -145,6 +184,16 @@ namespace Tensorflow.Keras.Engine | |||
return children; | |||
} | |||
public override void SetAttr(string name, object value) | |||
{ | |||
// TODO(Rinne): deal with "_self_setattr_tracking". | |||
//if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v))) | |||
//{ | |||
// this._base_model_initialized; | |||
//} | |||
base.SetAttr(name, value); | |||
} | |||
void IModel.set_stopTraining_true() | |||
{ | |||
@@ -1,12 +1,14 @@ | |||
using Newtonsoft.Json; | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.ComponentModel; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Reflection; | |||
using System.Text.RegularExpressions; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
@@ -17,6 +19,8 @@ using Tensorflow.Keras.Saving.SavedModel; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using Tensorflow.Util; | |||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
using static Tensorflow.ApiDef.Types; | |||
using static Tensorflow.Binding; | |||
@@ -190,12 +194,13 @@ namespace Tensorflow.Keras.Saving | |||
Name = config["name"].ToObject<string>() | |||
}); | |||
//s.Name = config["name"].ToObject<string>(); | |||
if(s.input is null || s.input.Length == 0) | |||
if(s.inputs is null || s.inputs.Length == 0) | |||
{ | |||
var first_layer = _get_child_layer_node_ids(model_id)[0]; | |||
var input_specs = _infer_inputs(first_layer); | |||
var input_shapes = _infer_inputs(first_layer, true); | |||
var input_shapes = _infer_input_shapes(first_layer); | |||
// `model._set_inputs(input_specs)` | |||
s._set_inputs(input_specs); | |||
// skip the check of input_specs is Dictionary | |||
if (!s.Built) | |||
@@ -220,12 +225,12 @@ namespace Tensorflow.Keras.Saving | |||
private void _set_network_attributes_from_metadata(Model revived_object) | |||
{ | |||
var metadata = revived_object.SerializedAttributes["matadata"] as JObject; | |||
if (metadata.ContainsKey("dtype")) | |||
var metadata = revived_object.SerializedAttributes["metadata"] as KerasMetaData; | |||
if (metadata.DType != TF_DataType.DtInvalid) | |||
{ | |||
// TODO(Rinne): set_dtype_policy. | |||
} | |||
revived_object.args.Trainable = metadata["trainable"].Value<bool>(); | |||
revived_object.args.Trainable = metadata.Trainable; | |||
} | |||
/// <summary> | |||
@@ -305,6 +310,11 @@ namespace Tensorflow.Keras.Saving | |||
private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | |||
{ | |||
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||
// Debug(Rinne) | |||
if(node_id == 11) | |||
{ | |||
Console.WriteLine(); | |||
} | |||
if (loaded_nodes.ContainsKey(node_id)) | |||
{ | |||
@@ -472,15 +482,7 @@ namespace Tensorflow.Keras.Saving | |||
} | |||
else | |||
{ | |||
var properties = layer.GetType().GetProperties(); | |||
foreach(var p in properties) | |||
{ | |||
if(p.Name == name as string && p.GetValue(layer) is not null) | |||
{ | |||
return; | |||
} | |||
} | |||
Loader.setattr(layer, name, value); | |||
layer.SetAttr(name as string, value); | |||
} | |||
} | |||
@@ -607,7 +609,7 @@ namespace Tensorflow.Keras.Saving | |||
if(build_input_shape is null) | |||
{ | |||
build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||
build_input_shape = _infer_input_shapes(node_id); | |||
} | |||
if(build_input_shape is not null) | |||
@@ -633,7 +635,7 @@ namespace Tensorflow.Keras.Saving | |||
/// <param name="layer_node_id"></param> | |||
/// <param name="convert_to_shapes"></param> | |||
/// <returns></returns> | |||
private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||
private TensorSpec _infer_inputs(int layer_node_id) | |||
{ | |||
var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | |||
if(call_fn_id is null) | |||
@@ -648,7 +650,22 @@ namespace Tensorflow.Keras.Saving | |||
} | |||
var call_fn_name = concrete_functions[0]; | |||
var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | |||
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
var structured_input_signature = nested_structure_coder.decode_proto(call_fn_proto.CanonicalizedInputSignature); | |||
Debug.Assert(structured_input_signature is IEnumerable); | |||
var first_enumerator = (structured_input_signature as IEnumerable).GetEnumerator(); | |||
first_enumerator.MoveNext(); | |||
var first = first_enumerator.Current; | |||
Debug.Assert(first is IEnumerable); | |||
var inputs_enumerator = (first as IEnumerable).GetEnumerator(); | |||
inputs_enumerator.MoveNext(); | |||
var inputs = inputs_enumerator.Current as TensorSpec; | |||
return inputs; | |||
} | |||
private Shape _infer_input_shapes(int layer_node_id) | |||
{ | |||
var inputs = _infer_inputs(layer_node_id); | |||
return nest.map_structure(x => x.shape, inputs); | |||
} | |||
private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||
@@ -48,19 +48,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
} | |||
else | |||
{ | |||
var properties = layer.GetType().GetProperties(); | |||
foreach (var p in properties) | |||
{ | |||
if ((string)name == p.Name) | |||
{ | |||
if(p.GetValue(layer) is not null) | |||
{ | |||
return; | |||
} | |||
p.SetValue(layer, value); | |||
return; | |||
} | |||
} | |||
layer.SetAttr(name as string, value); | |||
} | |||
} | |||
} | |||
@@ -11,7 +11,7 @@ using Tensorflow.Keras.Optimizers; | |||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Training; | |||
using System.Diagnostics; | |||
namespace Tensorflow.Keras.Saving.SavedModel; | |||
@@ -135,12 +135,17 @@ public partial class KerasSavedModelUtils | |||
if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
})); | |||
var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
Dictionary<string, Trackable> res = new(); | |||
res["variables"] = variables; | |||
res["trainable_variables"] = trainable_variables; | |||
res["non_trainable_variables"] = non_trainable_variables; | |||
res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
Debug.Assert(variables is Trackable); | |||
Debug.Assert(trainable_variables is Trackable); | |||
Debug.Assert(non_trainable_variables is Trackable); | |||
Debug.Assert(layers is Trackable); | |||
res["variables"] = variables as Trackable; | |||
res["trainable_variables"] = trainable_variables as Trackable; | |||
res["non_trainable_variables"] = non_trainable_variables as Trackable; | |||
res["layers"] = layers as Trackable; | |||
return res; | |||
} | |||
@@ -165,6 +165,14 @@ namespace Tensorflow.Keras.Utils | |||
} | |||
} | |||
public static bool has_weights(object obj) | |||
{ | |||
var obj_type = obj.GetType(); | |||
return obj_type.GetField("trainable_weights") is not null && | |||
obj_type.GetField("non_trainable_weights") is not null && | |||
obj is not Type; | |||
} | |||
// recusive | |||
static bool uses_keras_history(Tensor op_input) | |||
{ | |||
@@ -0,0 +1,22 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Util; | |||
namespace Tensorflow.Keras.Utils | |||
{ | |||
internal static class compile_utils | |||
{ | |||
public static List<string> create_pseudo_input_names(TensorSpec inputs) | |||
{ | |||
return _create_pseudo_names(inputs, "input_"); | |||
} | |||
private static List<string> _create_pseudo_names(TensorSpec tensors, string prefix) | |||
{ | |||
// TODO(Rinne): align with tensorflow | |||
return new List<string>() { $"{prefix}1" }; | |||
} | |||
} | |||
} |
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Framework.Models; | |||
namespace Tensorflow.Keras.Utils | |||
{ | |||
@@ -69,5 +70,29 @@ namespace Tensorflow.Keras.Utils | |||
false_fn: false_fn, | |||
name: name); | |||
} | |||
public static TensorSpec get_tensor_spec(Tensor t, bool dynamic_batch = false, string name = null) | |||
{ | |||
throw new NotImplementedException("The function is waited to be implemented in the future."); | |||
} | |||
public static TensorSpec get_tensor_spec(TensorSpec t, bool dynamic_batch = false, string name = null) | |||
{ | |||
var spec = t; | |||
if (!dynamic_batch) | |||
{ | |||
return spec; | |||
} | |||
var dynamic_batch_spec = new TensorSpec(t.shape, t.dtype, t.name); | |||
var shape = dynamic_batch_spec.shape; | |||
if(shape.rank > 0) | |||
{ | |||
var shape_list = shape.as_int_list(); | |||
// TODO(Rinne): check if -1 is equivalent to None in python. | |||
shape_list[0] = -1; | |||
dynamic_batch_spec.shape = new Shape(shape_list); | |||
} | |||
return dynamic_batch_spec; | |||
} | |||
} | |||
} |
@@ -64,5 +64,8 @@ public class SequentialModelLoad | |||
{ | |||
var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); | |||
model.summary(); | |||
var x = tf.ones((2, 10)); | |||
var y = model.Apply(x); | |||
} | |||
} |