@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
@@ -23,6 +25,26 @@ namespace Tensorflow | |||||
public class CompatApi
{
    public CompatV1Api v1 { get; } = new CompatV1Api();

    /// <summary>
    /// Returns the input string unchanged, mirroring `tf.compat.as_text` for inputs
    /// that are already text. The <paramref name="encoding"/> parameter exists only
    /// for signature parity with the byte[] overload and is intentionally unused.
    /// (BUGFIX: the original assigned a default to `encoding` and never read it — dead store.)
    /// </summary>
    internal string as_text(string bytes_or_text, Encoding? encoding = null)
    {
        return bytes_or_text;
    }

    /// <summary>
    /// Decodes the given bytes to a string; defaults to UTF-8 when no encoding is supplied.
    /// </summary>
    internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
    {
        return (encoding ?? Encoding.UTF8).GetString(bytes_or_text);
    }

    /// <summary>Alias of <see cref="as_text(string, Encoding?)"/>.</summary>
    internal string as_str(string bytes_or_text, Encoding? encoding = null)
        => as_text(bytes_or_text, encoding);

    /// <summary>Alias of <see cref="as_text(byte[], Encoding?)"/>.</summary>
    internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
        => as_text(bytes_or_text, encoding);
}
public bool executing_eagerly() | public bool executing_eagerly() | ||||
@@ -0,0 +1,152 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.IO; | |||||
using System.Linq; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
namespace Tensorflow.Checkpoint; | |||||
public static class CheckPointUtils
{
    // Mirrors `_ESCAPE_CHAR` in tensorflow/python/checkpoint; currently unreferenced in this class.
    private static string _ESCAPE_CHAR = ".";

    /// <summary>
    /// Traverses the object graph and returns, in BFS order: the trackables, the path from the
    /// root to each trackable, a node id per trackable, the serialized slot-variable protos,
    /// and a readable object name per trackable.
    /// </summary>
    public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>,
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
        Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }

        // Node ids are simply positions in the BFS order.
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }

        var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
        return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
    }

    /// <summary>
    /// Builds slot-variable references for every Optimizer in the graph. Non-optimizer
    /// trackables are skipped; slots can only be looked up for IVariableV1 instances.
    /// </summary>
    public static
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
            IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
    {
        var non_slot_objects = trackable_objects.ToList();
        Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
            slot_variables = new();
        foreach (var trackable in non_slot_objects)
        {
            if (trackable is not Optimizer optim)
            {
                continue;
            }
            var slot_names = optim.get_slot_names();
            foreach (var slot_name in slot_names)
            {
                for (int original_variable_node_id = 0;
                     original_variable_node_id < non_slot_objects.Count;
                     original_variable_node_id++)
                {
                    // BUGFIX: the original assigned `slot_variable = null` for non-variables and
                    // then still cast to IVariableV1 before calling `get_slot`, which would throw
                    // InvalidCastException. Only variables can own slots, so skip everything else.
                    if (non_slot_objects[original_variable_node_id] is not IVariableV1 original_variable)
                    {
                        continue;
                    }
                    var slot_variable = optim.get_slot(original_variable, slot_name);
                    if (slot_variable is null) continue;

                    // There're some problems about the inherits of `Variable` and `Trackable`.
                    throw new NotImplementedException();
                }
            }
        }

        return slot_variables;
    }

    /// <summary>
    /// Returns the mapped trackable from <paramref name="object_map"/> when present,
    /// otherwise the trackable itself.
    /// </summary>
    public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
    {
        if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
        {
            return trackable;
        }
        return possible_res;
    }

    /// <summary>
    /// Returns the full variable name used in the checkpoint, or "" for non-variable trackables.
    /// </summary>
    public static string get_full_name(Trackable variable)
    {
        // TODO: This state is not correct, the whole framework need to be updated in the future.
        if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
        {
            return "";
        }
        // skip the check of attribute `_save_slice_info`.

        // TODO: Need to be revised!!!
        Debug.Assert(variable is BaseResourceVariable);
        return ((BaseResourceVariable)variable).Name;
    }

    /// <summary>
    /// Computes which nodes (transitively) carry checkpoint values: a node counts if it has
    /// attributes or slot variables itself, or any descendant does. The flag is propagated from
    /// children up to parents via a BFS over reversed edges.
    /// </summary>
    public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
    {
        // Nodes that directly carry attributes or slot variables.
        HashSet<int> checkpointed_trackables = new();
        // child node id -> set of parent node ids (reverse edges).
        Dictionary<int, HashSet<int>> parents = new();
        for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        {
            var object_proto = object_graph_proto.Nodes[i];
            // skip the process of registered saver.
            if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
                object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
            {
                checkpointed_trackables.Add(i);
            }
            foreach (var child_proto in object_proto.Children)
            {
                var child = child_proto.NodeId;
                if (!parents.TryGetValue(child, out var parent_set))
                {
                    parent_set = new HashSet<int>();
                    parents[child] = parent_set;
                }
                parent_set.Add(i);
            }
        }

        // Propagate "has checkpoint values" upwards; removing processed entries from `parents`
        // prevents re-expanding already-visited nodes.
        Queue<int> to_visit = new(checkpointed_trackables);
        while (to_visit.Count > 0)
        {
            var trackable = to_visit.Dequeue();
            if (!parents.TryGetValue(trackable, out var current_parents)) continue;
            foreach (var parent in current_parents)
            {
                checkpointed_trackables.Add(parent);
                if (parents.ContainsKey(parent))
                {
                    to_visit.Enqueue(parent);
                }
            }
            parents.Remove(trackable);
        }

        // TODO: Complete it after supporting checkpoint.
        // for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        // {
        //     object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
        // }
    }
}
@@ -0,0 +1,5 @@ | |||||
namespace Tensorflow.Checkpoint;

/// <summary>
/// Options controlling checkpoint saving (named after tf.train.CheckpointOptions).
/// </summary>
/// <param name="experimental_io_device">Device string on which to perform checkpoint I/O; null means the default device.</param>
/// <param name="experimental_enable_async_checkpoint">Presumably enables asynchronous checkpoint writing — not consumed anywhere in this file; confirm against callers.</param>
public record class CheckpointOptions(
    string? experimental_io_device = null,
    bool experimental_enable_async_checkpoint = false);
@@ -0,0 +1,64 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Serilog.Debugging; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Checkpoint; | |||||
/// <summary>
/// A view over a trackable object graph used for checkpointing. Extends TrackableView with
/// optional extra root-level dependencies and a breadth-first traversal of the graph.
/// </summary>
public class ObjectGraphView: TrackableView, ICloneable
{
    // Extra dependencies attached to the root only (may be null).
    protected IEnumerable<TrackableReference>? _attached_dependencies;
    // TODO: attached_dependencies

    public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root)
    {
        _attached_dependencies = attached_dependencies;
    }

    /// <summary>
    /// NOTE(review): currently a shallow copy — root and dependencies are shared with the original.
    /// </summary>
    public object Clone()
    {
        // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__
        return new ObjectGraphView(Root, _attached_dependencies);
    }

    /// <summary>
    /// Lists the children of <paramref name="obj"/> as references; when <paramref name="obj"/>
    /// is the root, the attached dependencies are appended.
    /// </summary>
    public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
            .Select(x => new TrackableReference(x.Key, x.Value)).ToList();
        // Check the reference, not value.
        if (obj == Root && _attached_dependencies is not null)
        {
            res.AddRange(_attached_dependencies);
        }

        return res;
    }

    /// <summary>
    /// Children keyed by local name; overrides the base so attached dependencies are included.
    /// </summary>
    public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
    }

    public IEnumerable<TrackableReference>? AttachedDependencies
    {
        get => _attached_dependencies;
    }

    /// <summary>
    /// Returns all reachable trackables in BFS order, plus the path from the root to each.
    /// </summary>
    public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
    {
        return base._descendants_with_paths();
    }

    // TODO: complete the implementation
    public void serialize_object_graph(object? saveables_cache = null)
    {
        throw new NotImplementedException();
    }

    // TODO: complete the implementation
    public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
    {
        throw new NotImplementedException();
    }
}
@@ -0,0 +1,255 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
namespace Tensorflow.Checkpoint | |||||
{ | |||||
// Plain data holder describing one node of the object graph during serialization.
internal record class TrackableData(
    // A trackable in the root Trackable object graph.
    Trackable trackable,
    // The index at which the Trackable appears in TrackableObjectGraph.nodes.
    int node_id,
    // The BFS-generated path from the root object, used to generate readable checkpoint keys.
    string object_name,
    // A list of ObjectReference for each child connected to this Trackable.
    pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
    // A list of SlotVariableReference to save to the object (only valid for Optimizer objects).
    pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
    // The object to save to checkpoint. Usually this is the same as `trackable`,
    // but can differ when the caller wants to specify a different object to
    // save. For example, when saving checkpoints asynchronously, variables are
    // copied to the CPU. `object_to_save` is set as the copied variable.
    Trackable object_to_save
);
public static class SaveUtil
{
    /// <summary>
    /// Serializes the object graph rooted at <paramref name="graph_view"/>. Returns the tensors
    /// to save per trackable, the feed additions (always null until caches are supported), the
    /// registered savers, and the filled <see cref="TrackableObjectGraph"/> proto.
    /// </summary>
    public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
        serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
    {
        var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
        var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

        var object_graph_proto = fill_object_graph_proto(trackable_data);

        var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
        var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

        Dictionary<Tensor, object> feed_additions;
        if (cache is null)
        {
            feed_additions = null;
            serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
                cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
        }
        else
        {
            feed_additions = null;
            // TODO: deal with cache.
            // BUGFIX: the original threw `NotFiniteNumberException`, which misstates this
            // unimplemented path; use `NotImplementedException` like the rest of the file.
            throw new NotImplementedException();
        }

        CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
        return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
    }

    /// <summary>
    /// Walks the graph (BFS) and collects one <see cref="TrackableData"/> per trackable,
    /// plus the mapping from trackable to node id.
    /// </summary>
    private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }
        var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
        List<TrackableData> trackable_data = new();
        foreach (var trackable in trackable_objects)
        {
            pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto = new();
            foreach (var child in graph_view.list_children(trackable))
            {
                children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
                {
                    NodeId = node_ids[child.Refer],
                    LocalName = child.Name
                });
            }
            slot_variables.TryGetValue(trackable, out var slot_variable);
            trackable_data.Add(new TrackableData(
                trackable: trackable,
                node_id: node_ids[trackable],
                object_name: object_names[trackable],
                children_proto: children_proto,
                slot_variable_proto: slot_variable ?? new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>(),
                object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map)
            ));
        }
        return (trackable_data, node_ids);
    }

    /// <summary>
    /// Creates a proto with one node per trackable (in node-id order), carrying its
    /// children and slot-variable references.
    /// </summary>
    private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
    {
        TrackableObjectGraph object_graph_proto = new();
        for (int i = 0; i < trackable_data.Count; i++)
        {
            var td = trackable_data[i];
            Debug.Assert(td.node_id == i);
            object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
        }
        return object_graph_proto;
    }

    /// <summary>
    /// Creates dictionary of tensors to checkpoint, and updates the proto.
    /// </summary>
    /// <param name="tensor_trackables"></param>
    /// <param name="node_ids"></param>
    /// <param name="call_with_mapped_captures"></param>
    /// <param name="cache"></param>
    /// <param name="object_graph_proto"></param>
    private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
        bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
    {
        Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
        foreach (var td in tensor_trackables)
        {
            // TODO: deal with cache.
            var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
            Trackable trackable = null;
            IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
            if (!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
            {
                // Legacy path: the object only exposes SaveableObjects.
                (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
            }
            else
            {
                tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto);
                trackable = td.object_to_save;
            }

            if (trackable is not null)
            {
                serialized_tensors[trackable] = tensor_dict;
            }
            else
            {
                serialized_tensors[Trackable.None] = tensor_dict;
            }
        }
        return serialized_tensors;
    }

    /// <summary>
    /// Serializes a trackable via its own `serialize_to_tensors`, escaping local names,
    /// building checkpoint keys, and recording each tensor in the proto node.
    /// </summary>
    private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
    {
        var trackable = trackable_data.object_to_save;

        // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
        IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
        if (call_with_mapped_captures)
        {
            throw new NotImplementedException();
        }
        else
        {
            ret_tensor_dict = trackable.serialize_to_tensors();
        }

        // TODO: deal with the type `SaveSpec` (currently it will never be it).
        Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
        foreach (var pair in ret_tensor_dict)
        {
            var local_name = TrackableUtils.escape_local_name(pair.Key);
            var maybe_tensor = pair.Value;
            var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name);

            tensor_dict[checkpoint_key] = maybe_tensor;

            if (maybe_tensor.GetValueA() is SaveSpec)
            {
                throw new NotImplementedException();
                //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
            }

            if (object_graph_proto is not null)
            {
                object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
                {
                    Name = local_name,
                    CheckpointKey = checkpoint_key,
                    FullName = CheckPointUtils.get_full_name(trackable)
                });
            }
        }
        return tensor_dict;
    }

    /// <summary>
    /// Gets tensors to serialize from a Trackable with legacy SaveableObjects.
    /// </summary>
    /// <param name="trackable_data"></param>
    /// <param name="node_ids"></param>
    /// <param name="call_with_mapped_captures"></param>
    /// <param name="object_graph_proto"></param>
    /// <returns></returns>
    private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
        bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
    {
        Dictionary<Trackable, string> object_names = new();
        object_names[trackable_data.trackable] = trackable_data.object_name;
        Dictionary<Trackable, Trackable> object_map = new();
        object_map[trackable_data.trackable] = trackable_data.object_to_save;

        var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map);
        var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map,
            call_with_mapped_captures, saveables_cache: null);
        var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects);
        return (trackable, trackable.serialize_to_tensors());
    }

    /// <summary>
    /// Groups trackables per registered saver name. Registered-saver proto fields are not
    /// written yet (missing API on `TrackableObjectGraph.Types.TrackableObject`).
    /// </summary>
    private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
    {
        Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
        foreach (var pair in registered_trackables)
        {
            foreach (var td in pair.Value)
            {
                // BUGFIX: the original condition was inverted — it replaced the inner map when the
                // saver key already existed, and indexed a missing key otherwise (which threw
                // KeyNotFoundException on the first trackable of each saver).
                if (!registered_savers.TryGetValue(pair.Key, out var saver_map))
                {
                    saver_map = new Dictionary<string, Trackable>();
                    registered_savers[pair.Key] = saver_map;
                }
                saver_map[td.object_name] = td.object_to_save;

                var object_proto = object_graph_proto.Nodes[td.node_id];
                // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
            }
        }
        return registered_savers;
    }

    /// <summary>
    /// Splits trackables into tensor-backed, python-state, and registered groups. PyState and
    /// registration are not supported yet, so everything lands in the tensor group.
    /// </summary>
    private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
    {
        List<TrackableData> tensor_trackables = new();
        List<TrackableData> py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a placeholder.
        Dictionary<string, IList<TrackableData>> registered_trackables = new();

        foreach (var td in trackable_data)
        {
            // TODO: deal with registration.
            tensor_trackables.Add(td);
        }
        return (tensor_trackables, py_state_trackables, registered_trackables);
    }
}
} |
@@ -0,0 +1,222 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Exceptions; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using static Tensorflow.Binding; | |||||
using Google.Protobuf; | |||||
namespace Tensorflow.Checkpoint; | |||||
public static class SaveUtilV1
{
    /// <summary>
    /// Builds, per trackable, the list of (factory, attribute name, checkpoint key) used to
    /// create saveables. The second tuple item is always null: only internal registrations
    /// are allowed, so no registered saver is returned.
    /// </summary>
    public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
        IDictionary<Trackable, Trackable>? object_map = null)
    {
        // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md,
        // till now only internal registrations are allowed. So, we won't return a saver in this function.
        // The implementation of this function should be updated if tensorflow update it.
        Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new();
        foreach (var pair in object_names)
        {
            var trackable = pair.Key;
            var object_name = pair.Value;
            var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);

            // skip the registration process.

            List<CheckpointFactoryData> current_list = new();
            foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save))
            {
                // treat name as key_suffix.
                var name = name_and_factory.Key;
                var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name);

                current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key));
            }

            checkpoint_factory_map[trackable] = current_list;
        }
        return (checkpoint_factory_map, null);
    }

    /// <summary>
    /// Serializes the gathered objects and appends a NoRestoreSaveable carrying the serialized
    /// object-graph proto, either inside <paramref name="to_graph"/> or in the current context.
    /// </summary>
    public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
        IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
        object? saveables_cache = null)
    {
        if (to_graph is not null)
        {
            var g = to_graph.as_default();
            var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                object_map, call_with_mapped_captures, saveables_cache);
            // NOTE(review): the device context returned by tf.device is never disposed/exited here
            // — confirm whether the constant is really pinned to CPU.
            tf.device("/cpu:0");
            var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
            named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
            g.Exit();
            return (named_saveable_objects, registered_savers);
        }
        else
        {
            using (new ops.NullContextManager())
            {
                var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                    object_map, call_with_mapped_captures, saveables_cache);
                tf.device("/cpu:0");
                // BUGFIX: was `graph_proto.ToString()`, which yields the protobuf *text* format;
                // the graph branch above (and any checkpoint reader) expects the binary
                // serialization, so use ToByteArray() consistently.
                var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
                named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
                return (named_saveable_objects, registered_savers);
            }
        }
    }

    /// <summary>
    /// BFS-traverses the graph, fills the object-graph proto, generates the saveables, and
    /// marks nodes that carry checkpoint values.
    /// </summary>
    public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
        IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }
        var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
        var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables);
        var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph(
            trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures,
            saveables_cache);
        CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
        return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers);
    }

    /// <summary>
    /// Creates one proto node per trackable (in BFS order) with its slot variables and
    /// child references.
    /// </summary>
    private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
        IDictionary<Trackable, int> node_ids,
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
            slot_variables)
    {
        TrackableObjectGraph object_graph_proto = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            var trackable = trackable_objects[i];
            Debug.Assert(node_ids[trackable] == i);
            TrackableObjectGraph.Types.TrackableObject object_proto;
            if (slot_variables.TryGetValue(trackable, out var slots))
            {
                object_proto = new TrackableObjectGraph.Types.TrackableObject(slots);
            }
            else
            {
                object_proto = new TrackableObjectGraph.Types.TrackableObject();
            }
            object_graph_proto.Nodes.Add(object_proto);
            foreach (var child in graph_view.list_children(trackable))
            {
                object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
                    { NodeId = node_ids[child.Refer], LocalName = child.Name });
            }
        }

        return object_graph_proto;
    }

    /// <summary>
    /// Fills the proto nodes with serialized-tensor attributes. Registered savers are not
    /// supported yet, so the third tuple item is always null.
    /// </summary>
    private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
        TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
        IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
        bool call_with_mapped_captures, object? saveables_cache = null)
    {
        // Sanity check: trackables must be listed in node-id order.
        int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count);
        for (int i = 0; i < cnt; i++)
        {
            Debug.Assert(node_ids[trackable_objects[i]] == i);
        }

        var (checkpoint_factory_map, unmapped_registered_savers) =
            get_checkpoint_factories_and_keys(object_names, object_map);

        // skip the process of registered savers

        var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map,
            object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache);
        return (named_saveable_objects, feed_additions, null);
    }

    /// <summary>
    /// Materializes saveable objects from the factory map, validating that each produced
    /// saveable's name contains its checkpoint key, and optionally records the serialized
    /// tensors in the object-graph proto.
    /// </summary>
    public static (List<MySaveableObject>, object?) generate_saveable_objects(
        IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
        TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
        IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
    {
        List<MySaveableObject> named_saveable_objects = new();
        foreach (var pair in checkpoint_factory_map)
        {
            var trackable = pair.Key;
            var factory_data_list = pair.Value;
            bool fill_object_proto = object_graph_proto is not null && node_ids is not null;
            TrackableObjectGraph.Types.TrackableObject object_proto = null!;
            if (fill_object_proto)
            {
                object_proto = object_graph_proto.Nodes[node_ids[trackable]];
            }

            var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
            // skip cache

            foreach (var factory_data in factory_data_list)
            {
                var name = factory_data.name;
                var key = factory_data.checkpoint_key;
                var maybe_saveable = factory_data.factory;

                // TODO: oneflow python has a process with callable `saveable_factory`.
                List<MySaveableObject> saveables = new();
                if (maybe_saveable.DataType == typeof(MySaveableObject))
                {
                    saveables.Add(maybe_saveable.GetValueB());
                }
                else
                {
                    saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key));
                }

                foreach (var saveable in saveables)
                {
                    if (!saveable.name.Contains(key))
                    {
                        throw new AssertionError($"The object {trackable} produced a SaveableObject with name " +
                                                 $"'{saveable.name}' for attribute '{name}'. Expected a name" +
                                                 $" containing '{key}'.");
                    }
                }

                // skip the process of PythonState

                named_saveable_objects.AddRange(saveables);

                if (!fill_object_proto) continue;

                // skip the process of `TrackableSaveable` because of lack of APIs.

                object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
                    { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
            }
        }
        return (named_saveable_objects, null);
    }
}
/// <summary>
/// Bundles a saveable factory (either a resource variable or a ready-made saveable object)
/// with its attribute name and the checkpoint key it will be stored under.
/// </summary>
public record class CheckpointFactoryData
(
    Maybe<BaseResourceVariable, MySaveableObject> factory,
    string name,
    string checkpoint_key
);
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Checkpoint | |||||
{ | |||||
internal static class SaveableCompat | |||||
{ | |||||
public static string? get_saveable_name(Trackable cls_or_obj) | |||||
{ | |||||
// TODO: implement it with Attribute. | |||||
return null; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,82 @@ | |||||
using System; | |||||
using Tensorflow.Train; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
namespace Tensorflow.Checkpoint; | |||||
public class TrackableView
{
    // Weakly held so that the view does not keep the root object alive.
    protected WeakReference<Trackable> _root_ref;

    public TrackableView(Trackable obj) => _root_ref = new WeakReference<Trackable>(obj);

    public TrackableView(WeakReference<Trackable> obj) => _root_ref = obj;

    /// <summary>
    /// Returns the trackable children of `obj` as a fresh dictionary.
    /// </summary>
    public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
    {
        obj._maybe_initialize_trackable();
        // Note: in python the return type of `Trackable._trackable_children` is not fixed.
        // Therefore it uses `convert_to_trackable` to have an extra process.
        Dictionary<string, Trackable> result = new();
        foreach (var entry in obj._trackable_children(save_type, cache))
        {
            result[entry.Key] = entry.Value;
        }
        return result;
    }

    /// <summary>
    /// The root trackable; throws when the weak reference has been collected.
    /// </summary>
    public Trackable Root
    {
        get
        {
            if (!_root_ref.TryGetTarget(out Trackable target))
            {
                throw new InvalidDataException(
                    "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor.");
            }
            return target;
        }
    }

    /// <summary>
    /// Returns a list of all nodes and its paths from self.root using a breadth first traversal.
    /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths
    /// </summary>
    protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
    {
        var ordered = new List<Trackable>();
        var paths = new Dictionary<Trackable, IEnumerable<TrackableReference>>();
        var pending = new Queue<Trackable>();
        var root = Root;
        pending.Enqueue(root);
        paths[root] = new List<TrackableReference>();
        while (!pending.empty())
        {
            var node = pending.Dequeue();
            ordered.Add(node);
            foreach (var pair in this.children(node))
            {
                var child = pair.Value;
                // First visit wins: keep the shortest (BFS) path to each dependency.
                if (paths.ContainsKey(child))
                    continue;
                var extended = new List<TrackableReference>(paths[node])
                {
                    new TrackableReference(pair.Key, child)
                };
                paths[child] = extended;
                pending.Enqueue(child);
            }
        }
        return (ordered, paths);
    }
}
@@ -0,0 +1,195 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Contexts; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Checkpoint; | |||||
/// <summary>
/// Saves and restores a `Trackable` object and its dependencies.
/// </summary>
public class TrackableSaver
{
    private ObjectGraphView _graph_view;
    // Result of the most recent save; reused while the serialized object graph is unchanged.
    private Tensor _cached_save_operation;
    private TrackableObjectGraph _last_save_object_graph;
    private Tensor? _object_graph_feed_tensor = null;
    private Tensor? _file_prefix_feed_tensor = null;
    private Dictionary<Trackable, Trackable>? _object_map = null;
    private object? _cache = null;
    public TrackableSaver(ObjectGraphView graph_view)
    {
        _graph_view = graph_view;

        // TODO: cache when not executing eagerly.
        // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
        // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
    }

    /// <summary>
    /// Gathers the tensors to save from the graph view, adding the serialized object
    /// graph itself under `OBJECT_GRAPH_PROTO_KEY`.
    /// </summary>
    /// <param name="object_graph_tensor">Optional pre-built tensor for the object graph proto;
    /// when null a new constant is created (on cpu:0), otherwise the proto bytes are fed.</param>
    private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
        gather_serialized_tensors(Tensor? object_graph_tensor = null)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);

        // TODO: cache.

        if(object_graph_tensor is null)
        {
            tf.device("/cpu:0");
            object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
        }
        else
        {
            feed_additions[object_graph_tensor] = graph_proto.ToByteArray();
        }
        // The serialization above must not have produced the object-graph key itself.
        Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
        if (!serialized_tensors.ContainsKey(Trackable.None))
        {
            serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
        }
        serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
        return (serialized_tensors, feed_additions, registered_savers, graph_proto);
    }

    /// <summary>
    /// Builds (or reuses) the save operation for the current object graph and returns it
    /// together with any feed additions. The op is rebuilt when the serialized object graph
    /// changed, or always when executing eagerly / inside a function.
    /// </summary>
    private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

        Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
        {
            // NOTE(review): `!=` compares the protobuf messages by reference here — confirm that is intended.
            if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
            {
                var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
                var save_op = saver.save(file_prefix, options);

                // tensorflow python: `with ops.device("/cpu:0"):`
                using (ops.control_dependencies(new object[] { save_op }))
                {
                    _cached_save_operation = array_ops.identity(file_prefix);
                }
                _last_save_object_graph = graph_proto;
            }
            return (_cached_save_operation, feed_additions);
        };

        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException();
        }

        return run_save();
    }

    /// <summary>
    /// String-prefix convenience overload. Delegates to the tensor-based overload instead of
    /// duplicating the save logic (the previous implementation was a verbatim copy of it).
    /// </summary>
    private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
    {
        return save_cached_when_graph_building(tf.constant(file_prefix), object_graph_tensor, options);
    }

    // TODO: parameter write_done_callback
    /// <summary>
    /// Saves a training checkpoint with the given prefix.
    /// </summary>
    /// <param name="file_prefix">Prefix of the checkpoint files to write.</param>
    /// <param name="checkpoint_number">Optional number appended as `prefix-N`.</param>
    /// <param name="session">Session used in graph mode; a new one is created when omitted.</param>
    /// <param name="options">Checkpoint options; defaults are used when null.</param>
    /// <returns>The checkpoint-prefix tensor of the completed save.</returns>
    public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
        CheckpointOptions? options = null)
    {
        if (options is null)
        {
            options = new CheckpointOptions();
        }
        Dictionary<Tensor, object> feed_dict = new();
        bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
        if (checkpoint_number is not null)
        {
            file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
        }
        Tensor file_prefix_tensor;
        Tensor object_graph_tensor;
        string file_prefix_to_save;
        if (use_session)
        {
            // Graph mode: feed the prefix and the object graph through reusable tensors so the
            // same save op can be run repeatedly with different feeds.
            if (_object_graph_feed_tensor is null)
            {
                // In python there is `with ops.device("/cpu:0")`.
                _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
                _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
            }
            object_graph_tensor = _object_graph_feed_tensor;
            file_prefix_tensor = _file_prefix_feed_tensor;
            feed_dict[file_prefix_tensor] = file_prefix;
            file_prefix_to_save = "";
        }
        else
        {
            // In python there is `with ops.device("/cpu:0")`.
            file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
            object_graph_tensor = null;
            file_prefix_to_save = file_prefix;
        }

        var (save_path, new_feed_additions) =
            save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options);

        if (new_feed_additions is not null)
        {
            foreach (var pair in new_feed_additions)
            {
                feed_dict.Add(pair.Key, pair.Value);
            }
        }
        if(!use_session)
        {
            session = null;
        }
        else if (session is null)
        {
            session = new Session(); // In python it uses `get_session`.
        }

        if (session is not null)
        {
            var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
            return session.run((Tensor)save_path, s);
        }
        else if (use_session)
        {
            throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " +
                "in graph mode without a default session. Please use " +
                "`with tf.Session():` to create a session.");
        }
        else
        {
            return save_path;
        }
    }
}
@@ -0,0 +1,559 @@ | |||||
using System; | |||||
using System.Buffers.Text; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.ApiDef.Types; | |||||
using static Tensorflow.CostGraphDef.Types; | |||||
using static Tensorflow.OptimizerOptions.Types; | |||||
using static Tensorflow.Binding; | |||||
using System.Text.RegularExpressions; | |||||
using System.Linq; | |||||
using Tensorflow.Operations; | |||||
using Tensorflow.Training; | |||||
using Tensorflow.Graphs; | |||||
using System.Xml.Linq; | |||||
using System.Diagnostics; | |||||
namespace Tensorflow.Checkpoint | |||||
{ | |||||
/// <summary> | |||||
/// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. | |||||
/// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. | |||||
/// </summary> | |||||
public interface IFunctionHolder | |||||
{ | |||||
int ArgCount { get; } | |||||
object DynamicInvoke(params object[] args); | |||||
} | |||||
internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder | |||||
{ | |||||
public int ArgCount => 0; | |||||
public object DynamicInvoke(params object[] args) | |||||
{ | |||||
return Func.DynamicInvoke(args); | |||||
} | |||||
public TR Invoke() | |||||
{ | |||||
return Func.Invoke(); | |||||
} | |||||
} | |||||
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder | |||||
{ | |||||
public int ArgCount => 1; | |||||
public object DynamicInvoke(params object[] args) | |||||
{ | |||||
return Func.DynamicInvoke(args); | |||||
} | |||||
} | |||||
internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder | |||||
{ | |||||
public int ArgCount => 2; | |||||
public object DynamicInvoke(params object[] args) | |||||
{ | |||||
return Func.DynamicInvoke(args); | |||||
} | |||||
} | |||||
internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder | |||||
{ | |||||
public int ArgCount => 3; | |||||
public object DynamicInvoke(params object[] args) | |||||
{ | |||||
return Func.DynamicInvoke(args); | |||||
} | |||||
} | |||||
public class Maybe<TA, TB> | |||||
{ | |||||
private TA? _valueA = default(TA); | |||||
private TB? _valueB = default(TB); | |||||
private Type _type; | |||||
private bool _assigned = false; | |||||
public Maybe(TA value) | |||||
{ | |||||
_valueA = value; | |||||
_type= typeof(TA); | |||||
_assigned = true; | |||||
} | |||||
public Maybe(TB value) | |||||
{ | |||||
_valueB = value; | |||||
_type = typeof(TB); | |||||
_assigned = true; | |||||
} | |||||
public Type DataType => _type; | |||||
public TA GetValueA() | |||||
{ | |||||
if(!_assigned || DataType != typeof(TA)) | |||||
{ | |||||
throw new TypeError("Cannot get the data because of wrong specified type."); | |||||
} | |||||
return _valueA; | |||||
} | |||||
public TB GetValueB() | |||||
{ | |||||
if (!_assigned || DataType != typeof(TB)) | |||||
{ | |||||
throw new TypeError("Cannot get the data because of wrong specified type."); | |||||
} | |||||
return _valueB; | |||||
} | |||||
public object GetValue() | |||||
{ | |||||
if (!_assigned) | |||||
{ | |||||
throw new TypeError("Cannot get the data because of wrong specified type."); | |||||
} | |||||
if(DataType == typeof(TA) && _valueA is not null) | |||||
{ | |||||
return _valueA; | |||||
} | |||||
else if(DataType == typeof(TB) && _valueB is not null) | |||||
{ | |||||
return _valueB; | |||||
} | |||||
else if(DataType == typeof(TA)) | |||||
{ | |||||
return _valueA; | |||||
} | |||||
else | |||||
{ | |||||
return _valueB; | |||||
} | |||||
} | |||||
public static implicit operator Maybe<TA, TB>(TA a) | |||||
{ | |||||
return new Maybe<TA, TB>(a); | |||||
} | |||||
public static implicit operator Maybe<TA, TB>(TB b) | |||||
{ | |||||
return new Maybe<TA, TB>(b); | |||||
} | |||||
} | |||||
internal class SingleDeviceSaver | |||||
{ | |||||
private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||||
{ | |||||
_tensor_slice_dict = tensor_slice_dict; | |||||
} | |||||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, Tensor>> tensor_slice_dict) | |||||
{ | |||||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||||
x => x.Key, x => x.Value.ToDictionary( | |||||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||||
} | |||||
public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | |||||
{ | |||||
_tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||||
x => x.Key, x => x.Value.ToDictionary( | |||||
y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||||
as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||||
} | |||||
public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||||
{ | |||||
if(options is null) | |||||
{ | |||||
options = new CheckpointOptions(); | |||||
} | |||||
List<string> tensor_names = new(); | |||||
List<Tensor> tensors = new(); | |||||
List<string> slice_specs = new(); | |||||
foreach(var pair in _tensor_slice_dict) | |||||
{ | |||||
var checkpoint_key = pair.Key; | |||||
var tensor_slices = pair.Value; | |||||
foreach(var slice in tensor_slices) | |||||
{ | |||||
var slice_spec = slice.Key; | |||||
var maybe_tensor = slice.Value; | |||||
if(maybe_tensor.DataType == typeof(SaveSpec)) | |||||
{ | |||||
var spec = maybe_tensor.GetValueB(); | |||||
var tensor_value = spec.tensor; | |||||
if (tensor_value is not null) | |||||
{ | |||||
tensor_names.Add(spec.name); | |||||
tensors.Add(tensor_value); | |||||
slice_specs.Add(spec.slice_spec); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
var tensor = maybe_tensor.GetValueA(); | |||||
tensor_names.Add(checkpoint_key); | |||||
tensors.Add(tensor); | |||||
slice_specs.Add(slice_spec); | |||||
} | |||||
} | |||||
} | |||||
// TODO: specify the device. | |||||
return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); | |||||
} | |||||
public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); | |||||
public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_prefix, CheckpointOptions? options = null) | |||||
{ | |||||
if(options is null) | |||||
{ | |||||
options = new CheckpointOptions(); | |||||
} | |||||
List<string> tensor_names = new(); | |||||
List<TF_DataType> tensor_dtypes = new(); | |||||
List<string> slice_specs = new(); | |||||
foreach(var pair in _tensor_slice_dict) | |||||
{ | |||||
var checkpoint_key = pair.Key; | |||||
var tensor_slices = pair.Value; | |||||
foreach(var slice in tensor_slices) | |||||
{ | |||||
var slice_spec = slice.Key; | |||||
var maybe_tensor = slice.Value; | |||||
// TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||||
if(maybe_tensor.DataType == typeof(SaveSpec)) | |||||
{ | |||||
var spec = maybe_tensor.GetValueB(); | |||||
tensor_dtypes.Add(spec.dtype); | |||||
slice_specs.Add(spec.slice_spec); | |||||
tensor_names.Add(spec.name); | |||||
} | |||||
else | |||||
{ | |||||
var tensor = maybe_tensor.GetValueA(); | |||||
tensor_dtypes.Add(tensor.dtype); | |||||
slice_specs.Add(slice_spec); | |||||
tensor_names.Add(checkpoint_key); | |||||
} | |||||
} | |||||
} | |||||
string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; | |||||
// tf python has code `with ops.device(restore_device):` here. | |||||
tf.device(restore_device); // may be risky. | |||||
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||||
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||||
int idx = 0; | |||||
foreach(var pair in _tensor_slice_dict) | |||||
{ | |||||
var checkpoint_key = pair.Key; | |||||
var tensor_slices = pair.Value; | |||||
foreach(var slice_spec in tensor_slices.Keys) | |||||
{ | |||||
var restored_tensor = restored_tensors[idx++]; | |||||
if (!restored_tensor_dict.ContainsKey(checkpoint_key)) | |||||
{ | |||||
restored_tensor_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||||
} | |||||
restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; | |||||
} | |||||
} | |||||
return restored_tensor_dict; | |||||
} | |||||
public IDictionary<string, IDictionary<string, Tensor>> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix)); | |||||
} | |||||
/// <summary> | |||||
/// Saves checkpoints directly from multiple devices. | |||||
/// Note that this is a low-level utility which stores Tensors in the keys | |||||
/// specified by `SaveableObject`s.Higher-level utilities for object-based | |||||
/// checkpointing are built on top of it. | |||||
/// </summary> | |||||
public class MultiDeviceSaver | |||||
{ | |||||
private Dictionary<string, SingleDeviceSaver> _single_device_savers; | |||||
private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers; | |||||
private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; | |||||
private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys; | |||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||||
/// <param name="registered_savers"></param> | |||||
/// <param name="call_with_mapped_capture"></param> | |||||
public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||||
IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||||
{ | |||||
_keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); | |||||
_restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>(); | |||||
Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||||
foreach(var pair in serialized_tensors) | |||||
{ | |||||
var obj = pair.Key; | |||||
var tensor_dict = pair.Value; | |||||
IFunctionHolder restore_fn; | |||||
if(obj == Trackable.None) | |||||
{ | |||||
restore_fn = new FunctionHolder<object?>(() => null); | |||||
} | |||||
else | |||||
{ | |||||
restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x => | |||||
{ | |||||
return obj._restore_from_tensors(x); | |||||
}); | |||||
} | |||||
foreach(var item in tensor_dict) | |||||
{ | |||||
var checkpoint_key = item.Key; | |||||
IDictionary<string, Tensor> spec_to_tensor; | |||||
if(item.Value.DataType != typeof(IDictionary<string, Tensor>)) | |||||
{ | |||||
spec_to_tensor = new Dictionary<string, Tensor>(); | |||||
spec_to_tensor[""] = item.Value.GetValueA(); | |||||
} | |||||
else | |||||
{ | |||||
spec_to_tensor = item.Value.GetValueB(); | |||||
} | |||||
foreach(var spec in spec_to_tensor) | |||||
{ | |||||
var slice_spec = spec.Key; | |||||
var tensor = spec.Value; | |||||
if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) | |||||
{ | |||||
throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + | |||||
$"slice spec. This is invalid because one will overwrite the " + | |||||
$"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); | |||||
} | |||||
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | |||||
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||||
// skip the process of device name because lack of API. | |||||
var host_device = tensor.Device; | |||||
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||||
{ | |||||
internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||||
} | |||||
internal_dict[checkpoint_key][slice_spec] = tensor; | |||||
} | |||||
} | |||||
} | |||||
_single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); | |||||
_registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>(); | |||||
if(registered_savers is not null && registered_savers.Count > 0) | |||||
{ | |||||
// TODO: complete the implementation. | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
public Operation save(Tensor file_prefix, CheckpointOptions? options= null) | |||||
{ | |||||
if(options is null) | |||||
{ | |||||
options = new CheckpointOptions(); | |||||
} | |||||
tf.device("CPU"); // may be risky. | |||||
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||||
constant_op.constant(".part"), constant_op.constant("_temp/part")); | |||||
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||||
Operation save_fn() | |||||
{ | |||||
List<Tensor> saved_prefixes= new(); | |||||
foreach(var saver in _registered_savers) | |||||
{ | |||||
// TODO: implementi it later. | |||||
throw new NotImplementedException(); | |||||
} | |||||
int num_shards = _single_device_savers.Count; | |||||
List<Operation> sharded_saves = new(); | |||||
var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); | |||||
string? last_device = null; | |||||
int shard = 0; | |||||
foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) | |||||
{ | |||||
var device = pair.Key; | |||||
var saver = pair.Value; | |||||
last_device = device; | |||||
// skip the extra process of device name because of lack of API. | |||||
tf.device(device); | |||||
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||||
saved_prefixes.Add(shard_prefix); | |||||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||||
} | |||||
using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||||
{ | |||||
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||||
tf.device(merge_device); | |||||
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||||
} | |||||
} | |||||
if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) | |||||
{ | |||||
// TODO: implement it. Currently `autograph` does not support the function with non parameter. | |||||
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
return save_fn(); | |||||
} | |||||
} | |||||
public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); | |||||
public IDictionary<string, Operation> restore(Tensor file_prefix, CheckpointOptions? options = null) | |||||
{ | |||||
if(options is null) | |||||
{ | |||||
options = new CheckpointOptions(); | |||||
} | |||||
IDictionary<string, Operation> restore_func() | |||||
{ | |||||
Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||||
Dictionary<string, Operation> restore_ops = new(); | |||||
foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) | |||||
{ | |||||
var device = single_saver.Key; | |||||
var saver = single_saver.Value; | |||||
tf.device(device); | |||||
var restored_tensor_dict = saver.restore(file_prefix, options); | |||||
foreach(var pair in restored_tensor_dict) | |||||
{ | |||||
var checkpoint_key = pair.Key; | |||||
var slice_and_tensor = pair.Value; | |||||
foreach(var item in slice_and_tensor) | |||||
{ | |||||
var slice_spec = item.Key; | |||||
var tensor = item.Value; | |||||
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||||
if (!string.IsNullOrEmpty(slice_spec)) | |||||
{ | |||||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||||
{ | |||||
Dictionary<string, Tensor> dict = new(); | |||||
dict[slice_spec] = tensor; | |||||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||||
} | |||||
else | |||||
{ | |||||
internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||||
} | |||||
restore_fn_input_count[restore_fn]--; | |||||
if (restore_fn_input_count[restore_fn] == 0) | |||||
{ | |||||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
foreach(var input in restore_fn_inputs[restore_fn]) | |||||
{ | |||||
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||||
} | |||||
var ret = restore_fn.DynamicInvoke(restored_tensors); | |||||
if(ret is IDictionary<string, Operation>) | |||||
{ | |||||
var dict = (IDictionary<string, Operation>)ret; | |||||
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
foreach(var item in _registered_savers) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
return restore_ops; | |||||
} | |||||
// TODO: complete the implementation. Currently skip it because of lack of API. | |||||
bool has_custom_device_saver = false; | |||||
if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) | |||||
{ | |||||
// TODO: implement it. Currently `autograph` does not support the function with non parameter. | |||||
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
return restore_func(); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Serializes to a SaverDef referencing the current graph. | |||||
/// </summary> | |||||
public SaverDef to_proto() | |||||
{ | |||||
var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); | |||||
var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); | |||||
var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); | |||||
var save_tensor = traced_save_func(filename_tensor); | |||||
var restore_op = traced_restore_func(filename_tensor).op; | |||||
return new SaverDef() | |||||
{ | |||||
FilenameTensorName = filename_tensor.name, | |||||
SaveTensorName = save_tensor.name, | |||||
RestoreOpName = restore_op.name, | |||||
Version = SaverDef.Types.CheckpointFormatVersion.V2 | |||||
}; | |||||
} | |||||
private Tensor _traced_save(Tensor file_prefix) | |||||
{ | |||||
var save_op = save(file_prefix); | |||||
tf.device("cpu:0"); | |||||
using (ops.control_dependencies(new object[]{ save_op })) | |||||
{ | |||||
return array_ops.identity(file_prefix); | |||||
} | |||||
} | |||||
private Tensor _traced_restore(Tensor file_prefix) | |||||
{ | |||||
var restore_op = restore(file_prefix); | |||||
tf.device("cpu:0"); | |||||
using (ops.control_dependencies(restore_op.Values.ToArray())) | |||||
{ | |||||
return array_ops.identity(file_prefix); | |||||
} | |||||
} | |||||
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||||
{ | |||||
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||||
foreach (var saveable in saveables) | |||||
{ | |||||
var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||||
serialized_tensors[trackable] = trackable.serialize_to_tensors(); | |||||
} | |||||
return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); | |||||
} | |||||
private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) | |||||
{ | |||||
return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); | |||||
} | |||||
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) | |||||
{ | |||||
return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); | |||||
} | |||||
} | |||||
} |
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
using Tensorflow.Train; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -90,4 +91,71 @@ namespace Tensorflow | |||||
Dispose(false); | Dispose(false); | ||||
} | } | ||||
} | } | ||||
    /// <summary>
    /// A <see cref="Trackable"/> that owns an unmanaged handle and releases it through
    /// the standard dispose pattern. Subclasses must implement
    /// <see cref="DisposeUnmanagedResources(IntPtr)"/> and may override
    /// <see cref="DisposeManagedResources"/>.
    /// </summary>
    public abstract class DisposableTrackableObject: Trackable, IDisposable
    {
        // Unmanaged handle owned by this object; reset to IntPtr.Zero once released.
        protected IntPtr _handle;
        // Guards against running the dispose logic twice.
        protected bool _disposed;

        protected DisposableTrackableObject()
        { }

        protected DisposableTrackableObject(IntPtr handle)
            => _handle = handle;

        /// <summary>
        /// Core dispose logic shared by <see cref="Dispose()"/> and the finalizer.
        /// </summary>
        /// <param name="disposing">True when invoked from <see cref="Dispose()"/>;
        /// managed state is only released in that case.</param>
        private void Dispose(bool disposing)
        {
            if (_disposed)
                return;

            //first handle managed, they might use the unmanaged resources.
            if (disposing)
            {
                // dispose managed state (managed objects).
                DisposeManagedResources();
            }

            // free unmanaged memory
            if (_handle != IntPtr.Zero)
            {
                // Call the appropriate methods to clean up
                // unmanaged resources here.
                // If disposing is false,
                // only the following code is executed.
                DisposeUnmanagedResources(_handle);
                _handle = IntPtr.Zero;
            }

            // Note disposing has been done.
            _disposed = true;
        }

        /// <summary>
        /// Dispose any managed resources.
        /// </summary>
        /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks>
        protected virtual void DisposeManagedResources()
        { }

        /// <summary>
        /// Dispose any unmanaged resources related to given <paramref name="handle"/>.
        /// </summary>
        protected abstract void DisposeUnmanagedResources(IntPtr handle);

        public void Dispose()
        {
            Dispose(true);

            // This object will be cleaned up by the Dispose method.
            // Therefore, you should call GC.SuppressFinalize to
            // take this object off the finalization queue
            // and prevent finalization code for this object
            // from executing a second time.
            GC.SuppressFinalize(this);
        }

        // Finalizer: releases only the unmanaged handle if Dispose was never called.
        ~DisposableTrackableObject()
        {
            Dispose(false);
        }
    }
} | } |
@@ -0,0 +1,31 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Xml.Linq; | |||||
using Tensorflow.Contexts; | |||||
using static Tensorflow.ApiDef.Types; | |||||
using static Tensorflow.CostGraphDef.Types; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
internal class execute | |||||
{ | |||||
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||||
{ | |||||
var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); | |||||
var types = v.Select(t => t.dtype.as_datatype_enum()); | |||||
return (types.ToArray(), v.ToArray()); | |||||
} | |||||
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||||
{ | |||||
string device_name = ctx.DeviceName; | |||||
ctx.ensure_initialized(); | |||||
var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs); | |||||
return tensors; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
namespace Tensorflow.Exceptions; | |||||
/// <summary>
/// Raised when a runtime assertion fails, mirroring Python TensorFlow's
/// <c>AssertionError</c>.
/// </summary>
public class AssertionError : TensorflowException
{
    /// <summary>Creates the error with no message.</summary>
    public AssertionError() : base() { }

    /// <summary>Creates the error with the given message.</summary>
    public AssertionError(string message) : base(message) { }
}
@@ -304,7 +304,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
private static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
public static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
{ | { | ||||
var used_ops = ops_used_by_graph_def(graph_def); | var used_ops = ops_used_by_graph_def(graph_def); | ||||
@@ -345,5 +345,89 @@ namespace Tensorflow | |||||
return used_ops.ToArray(); | return used_ops.ToArray(); | ||||
} | } | ||||
        /// <summary>
        /// Returns whether <paramref name="attr_value"/> is the default declared for
        /// <paramref name="attr_name"/> on <paramref name="op_def"/>.
        /// </summary>
        /// <remarks>
        /// Incomplete: it currently only checks that a default exists and never compares
        /// the actual values (see TODO), so it can report true for non-default values.
        /// </remarks>
        private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value)
        {
            foreach (var attr_def in op_def.Attr)
            {
                if (attr_def.Name == attr_name)
                {
                    // No declared default means the attr can never be "default-valued".
                    if (attr_def.DefaultValue is null) return false;

                    // TODO: add new c_api `EqualAttrValueWrapper` and complete the check.
                    return true;
                }
            }
            // Attr not declared on this op at all.
            return false;
        }
        /// <summary>
        /// Removes node attrs that match their op's declared default, then marks the
        /// meta graph as having stripped default attrs. Nodes that invoke functions from
        /// the graph's function library are left untouched.
        /// </summary>
        /// <param name="meta_graph_def">Meta graph mutated in place.</param>
        public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def)
        {
            // Index library functions by signature name so function-call nodes can be skipped.
            Dictionary<string, FunctionDef> op_name_to_function = new();
            foreach (var function_def in meta_graph_def.GraphDef.Library.Function)
            {
                op_name_to_function[function_def.Signature.Name] = function_def;
            }

            Action<NodeDef> _strip_node_default_valued_attrs = (node_def) =>
            {
                // Function calls carry their own attr semantics; don't strip them.
                if (op_name_to_function.ContainsKey(node_def.Op)) return;

                var op_def = op_def_registry.GetOpDef(node_def.Op);
                // Unknown op: no OpDef to compare defaults against.
                if(op_def is null) return;

                // Collect first, remove after, to avoid mutating the map while iterating it.
                HashSet<string> attrs_to_strip = new();
                foreach (var attr in node_def.Attr)
                {
                    if (is_default_attr_value(op_def, attr.Key, attr.Value))
                    {
                        attrs_to_strip.Add(attr.Key);
                    }
                }

                foreach (var attr in attrs_to_strip)
                {
                    node_def.Attr.Remove(attr);
                }
            };

            // Strip top-level graph nodes, then the nodes inside every library function.
            foreach (var node_def in meta_graph_def.GraphDef.Node)
            {
                _strip_node_default_valued_attrs(node_def);
            }

            foreach (var function_def in meta_graph_def.GraphDef.Library.Function)
            {
                foreach (var function_node_def in function_def.NodeDef)
                {
                    _strip_node_default_valued_attrs(function_node_def);
                }
            }

            // Record the stripping so loaders know to restore defaults on import.
            meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
        }
/// <summary> | |||||
/// Extract the Op name from a Tensor name. | |||||
/// </summary> | |||||
/// <param name="tensor_name"></param> | |||||
/// <returns></returns> | |||||
public static string op_name(string tensor_name) | |||||
{ | |||||
if (string.IsNullOrEmpty(tensor_name)) | |||||
{ | |||||
throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); | |||||
} | |||||
if (tensor_name.StartsWith("^")) | |||||
{ | |||||
tensor_name = tensor_name.Substring(1); | |||||
} | |||||
if (tensor_name.Contains(":")) | |||||
{ | |||||
return tensor_name.Split(':')[0]; | |||||
} | |||||
return tensor_name; | |||||
} | |||||
} | } | ||||
} | } |
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
@@ -10,7 +11,7 @@ namespace Tensorflow.Functions | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
/// </summary> | /// </summary> | ||||
public class ConcreteFunction | |||||
public class ConcreteFunction: Trackable | |||||
{ | { | ||||
FuncGraph func_graph; | FuncGraph func_graph; | ||||
ForwardBackwardCall forward_backward; | ForwardBackwardCall forward_backward; | ||||
@@ -1,16 +1,23 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Train; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class Function | |||||
public class Function: Trackable | |||||
{ | { | ||||
#pragma warning disable CS0169 // The field 'Function._handle' is never used | #pragma warning disable CS0169 // The field 'Function._handle' is never used | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
#pragma warning restore CS0169 // The field 'Function._handle' is never used | #pragma warning restore CS0169 // The field 'Function._handle' is never used | ||||
public string Name { get; set; } | |||||
public Function() | public Function() | ||||
{ | { | ||||
} | } | ||||
public Function(string name) | |||||
{ | |||||
Name = name; | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using System; | using System; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -6,14 +7,14 @@ namespace Tensorflow.Graphs | |||||
{ | { | ||||
public class AutoGraph | public class AutoGraph | ||||
{ | { | ||||
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | |||||
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32) | |||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | ||||
var graph = new FuncGraph(func_name); | var graph = new FuncGraph(func_name); | ||||
graph.as_default(); | graph.as_default(); | ||||
var input = tf.placeholder(tf.int32); | |||||
var input = tf.placeholder(dtype); | |||||
var output = func(input); | var output = func(input); | ||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
@@ -26,25 +27,33 @@ namespace Tensorflow.Graphs | |||||
return (Tensor input) => | return (Tensor input) => | ||||
{ | { | ||||
var result = tf.Runner.TFE_Execute(tf.Context, | |||||
tf.Context.DeviceName, | |||||
func_name, | |||||
new[] { input }, | |||||
null, | |||||
1); | |||||
return result[0]; | |||||
if (tf.executing_eagerly()) | |||||
{ | |||||
var result = tf.Runner.TFE_Execute(tf.Context, | |||||
tf.Context.DeviceName, | |||||
func_name, | |||||
new[] { input }, | |||||
null, | |||||
1); | |||||
return result[0]; | |||||
} | |||||
using (var s = tf.Session(input.graph)) | |||||
{ | |||||
var output = func(input); | |||||
return output; | |||||
} | |||||
}; | }; | ||||
} | } | ||||
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | |||||
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes) | |||||
{ | { | ||||
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | ||||
var graph = new FuncGraph(func_name); | var graph = new FuncGraph(func_name); | ||||
graph.as_default(); | graph.as_default(); | ||||
var input1 = tf.placeholder(tf.int32); | |||||
var input2 = tf.placeholder(tf.int32); | |||||
var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); | |||||
var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); | |||||
var output = func(input1, input2); | var output = func(input1, input2); | ||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
@@ -56,13 +65,22 @@ namespace Tensorflow.Graphs | |||||
return (Tensor a, Tensor b) => | return (Tensor a, Tensor b) => | ||||
{ | { | ||||
var result = tf.Runner.TFE_Execute(tf.Context, | |||||
if (tf.executing_eagerly()) | |||||
{ | |||||
var result = tf.Runner.TFE_Execute(tf.Context, | |||||
tf.Context.DeviceName, | tf.Context.DeviceName, | ||||
func_name, | func_name, | ||||
new[] { a, b }, | new[] { a, b }, | ||||
null, | null, | ||||
1); | 1); | ||||
return result[0]; | |||||
return result[0]; | |||||
} | |||||
using (var s = tf.Session(a.graph)) | |||||
{ | |||||
Debug.Assert(a.graph == b.graph); | |||||
var output = func(a, b); | |||||
return output; | |||||
} | |||||
}; | }; | ||||
} | } | ||||
} | } | ||||
@@ -1,9 +1,18 @@ | |||||
using System; | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Keras.ArgsDefinition { | namespace Tensorflow.Keras.ArgsDefinition { | ||||
public class SoftmaxArgs : LayerArgs { | |||||
public Axis axis { get; set; } = -1; | |||||
} | |||||
public class SoftmaxArgs : LayerArgs | |||||
{ | |||||
[JsonProperty("axis")] | |||||
public Axis axis { get; set; } = -1; | |||||
[JsonProperty("name")] | |||||
public override string Name { get => base.Name; set => base.Name = value; } | |||||
[JsonProperty("trainable")] | |||||
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
[JsonProperty("dtype")] | |||||
public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
} | |||||
} | } |
@@ -0,0 +1,19 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class AutoSerializeLayerArgs: LayerArgs | |||||
{ | |||||
[JsonProperty("name")] | |||||
public override string Name { get => base.Name; set => base.Name = value; } | |||||
[JsonProperty("dtype")] | |||||
public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||||
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
[JsonProperty("trainable")] | |||||
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
} | |||||
} |
@@ -1,13 +1,18 @@ | |||||
using System; | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Xml.Linq; | |||||
using Tensorflow.Operations.Initializers; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
{ | { | ||||
// TODO: `activity_regularizer` | |||||
public class DenseArgs : LayerArgs | public class DenseArgs : LayerArgs | ||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Positive integer, dimensionality of the output space. | /// Positive integer, dimensionality of the output space. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("units")] | |||||
public int Units { get; set; } | public int Units { get; set; } | ||||
/// <summary> | /// <summary> | ||||
@@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
/// </summary> | /// </summary> | ||||
public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
private string _activationName; | |||||
[JsonProperty("activation")] | |||||
public string ActivationName | |||||
{ | |||||
get | |||||
{ | |||||
if (string.IsNullOrEmpty(_activationName)) | |||||
{ | |||||
return Activation.Method.Name; | |||||
} | |||||
else | |||||
{ | |||||
return _activationName; | |||||
} | |||||
} | |||||
set | |||||
{ | |||||
_activationName = value; | |||||
} | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Whether the layer uses a bias vector. | /// Whether the layer uses a bias vector. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("use_bias")] | |||||
public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
/// <summary> | /// <summary> | ||||
/// Initializer for the `kernel` weights matrix. | /// Initializer for the `kernel` weights matrix. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("kernel_initializer")] | |||||
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
/// <summary> | /// <summary> | ||||
/// Initializer for the bias vector. | /// Initializer for the bias vector. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("bias_initializer")] | |||||
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
/// <summary> | /// <summary> | ||||
/// Regularizer function applied to the `kernel` weights matrix. | /// Regularizer function applied to the `kernel` weights matrix. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("kernel_regularizer")] | |||||
public IRegularizer KernelRegularizer { get; set; } | public IRegularizer KernelRegularizer { get; set; } | ||||
/// <summary> | /// <summary> | ||||
/// Regularizer function applied to the bias vector. | /// Regularizer function applied to the bias vector. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("bias_regularizer")] | |||||
public IRegularizer BiasRegularizer { get; set; } | public IRegularizer BiasRegularizer { get; set; } | ||||
/// <summary> | /// <summary> | ||||
/// Constraint function applied to the `kernel` weights matrix. | /// Constraint function applied to the `kernel` weights matrix. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("kernel_constraint")] | |||||
public Action KernelConstraint { get; set; } | public Action KernelConstraint { get; set; } | ||||
/// <summary> | /// <summary> | ||||
/// Constraint function applied to the bias vector. | /// Constraint function applied to the bias vector. | ||||
/// </summary> | /// </summary> | ||||
[JsonProperty("bias_constraint")] | |||||
public Action BiasConstraint { get; set; } | public Action BiasConstraint { get; set; } | ||||
[JsonProperty("name")] | |||||
public override string Name { get => base.Name; set => base.Name = value; } | |||||
[JsonProperty("dtype")] | |||||
public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
[JsonProperty("trainable")] | |||||
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
} | } | ||||
} | } |
@@ -1,9 +1,22 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
using Newtonsoft.Json; | |||||
using Newtonsoft.Json.Serialization; | |||||
using Tensorflow.Keras.Common; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class InputLayerArgs : LayerArgs | public class InputLayerArgs : LayerArgs | ||||
{ | { | ||||
[JsonIgnore] | |||||
public Tensor InputTensor { get; set; } | public Tensor InputTensor { get; set; } | ||||
public bool Sparse { get; set; } | |||||
[JsonProperty("sparse")] | |||||
public virtual bool Sparse { get; set; } | |||||
[JsonProperty("ragged")] | |||||
public bool Ragged { get; set; } | public bool Ragged { get; set; } | ||||
[JsonProperty("name")] | |||||
public override string Name { get => base.Name; set => base.Name = value; } | |||||
[JsonProperty("dtype")] | |||||
public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||||
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
} | } | ||||
} | } |
@@ -1,8 +1,9 @@ | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
{ | { | ||||
public class DataAdapterArgs | |||||
public class DataAdapterArgs: IKerasConfig | |||||
{ | { | ||||
public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
@@ -1,8 +1,9 @@ | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
{ | { | ||||
public class DataHandlerArgs | |||||
public class DataHandlerArgs: IKerasConfig | |||||
{ | { | ||||
public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
@@ -1,51 +1,54 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
using Newtonsoft.Json; | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class LayerArgs | |||||
[JsonObject(MemberSerialization.OptIn)] | |||||
public class LayerArgs: IKerasConfig | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Indicates whether the layer's weights are updated during training | /// Indicates whether the layer's weights are updated during training | ||||
/// and whether the layer's updates are run during training. | /// and whether the layer's updates are run during training. | ||||
/// </summary> | /// </summary> | ||||
public bool Trainable { get; set; } = true; | |||||
public string Name { get; set; } | |||||
public virtual bool Trainable { get; set; } = true; | |||||
public virtual string Name { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Only applicable to input layers. | /// Only applicable to input layers. | ||||
/// </summary> | /// </summary> | ||||
public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; | |||||
public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; | |||||
/// <summary> | /// <summary> | ||||
/// Whether the `call` method can be used to build a TF graph without issues. | /// Whether the `call` method can be used to build a TF graph without issues. | ||||
/// This attribute has no effect if the model is created using the Functional | /// This attribute has no effect if the model is created using the Functional | ||||
/// API. Instead, `model.dynamic` is determined based on the internal layers. | /// API. Instead, `model.dynamic` is determined based on the internal layers. | ||||
/// </summary> | /// </summary> | ||||
public bool Dynamic { get; set; } = false; | |||||
public virtual bool Dynamic { get; set; } = false; | |||||
/// <summary> | /// <summary> | ||||
/// Only applicable to input layers. | /// Only applicable to input layers. | ||||
/// </summary> | /// </summary> | ||||
public Shape InputShape { get; set; } | |||||
public virtual Shape InputShape { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Only applicable to input layers. | /// Only applicable to input layers. | ||||
/// </summary> | /// </summary> | ||||
public Shape BatchInputShape { get; set; } | |||||
public virtual Shape BatchInputShape { get; set; } | |||||
public int BatchSize { get; set; } = -1; | |||||
public virtual int BatchSize { get; set; } = -1; | |||||
/// <summary> | /// <summary> | ||||
/// Initial weight values. | /// Initial weight values. | ||||
/// </summary> | /// </summary> | ||||
public float[] Weights { get; set; } | |||||
public virtual float[] Weights { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Regularizer function applied to the output of the layer(its "activation"). | /// Regularizer function applied to the output of the layer(its "activation"). | ||||
/// </summary> | /// </summary> | ||||
public IRegularizer ActivityRegularizer { get; set; } | |||||
public virtual IRegularizer ActivityRegularizer { get; set; } | |||||
public bool Autocast { get; set; } | |||||
public virtual bool Autocast { get; set; } | |||||
public bool IsFromConfig { get; set; } | |||||
public virtual bool IsFromConfig { get; set; } | |||||
} | } | ||||
} | } |
@@ -1,6 +1,8 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class NodeArgs | |||||
public class NodeArgs: IKerasConfig | |||||
{ | { | ||||
public ILayer[] InboundLayers { get; set; } | public ILayer[] InboundLayers { get; set; } | ||||
public int[] NodeIndices { get; set; } | public int[] NodeIndices { get; set; } | ||||
@@ -1,6 +1,8 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class OptimizerV2Args | |||||
public class OptimizerV2Args: IKerasConfig | |||||
{ | { | ||||
public string Name { get; set; } | public string Name { get; set; } | ||||
public float LearningRate { get; set; } = 0.001f; | public float LearningRate { get; set; } = 0.001f; | ||||
@@ -1,7 +1,10 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
using Newtonsoft.Json; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class FlattenArgs : LayerArgs | |||||
public class FlattenArgs : AutoSerializeLayerArgs | |||||
{ | { | ||||
[JsonProperty("data_format")] | |||||
public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,50 @@ | |||||
using Newtonsoft.Json; | |||||
using Newtonsoft.Json.Converters; | |||||
using Newtonsoft.Json.Linq; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Common | |||||
{ | |||||
public class CustomizedActivationJsonConverter : JsonConverter | |||||
{ | |||||
public override bool CanConvert(Type objectType) | |||||
{ | |||||
return objectType == typeof(Activation); | |||||
} | |||||
public override bool CanRead => true; | |||||
public override bool CanWrite => true; | |||||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
{ | |||||
if (value is null) | |||||
{ | |||||
var token = JToken.FromObject(""); | |||||
token.WriteTo(writer); | |||||
} | |||||
else if (value is not Activation) | |||||
{ | |||||
throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); | |||||
} | |||||
else | |||||
{ | |||||
var token = JToken.FromObject((value as Activation)!.GetType().Name); | |||||
token.WriteTo(writer); | |||||
} | |||||
} | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
//var dims = serializer.Deserialize(reader, typeof(string)); | |||||
//if (dims is null) | |||||
//{ | |||||
// throw new ValueError("Cannot deserialize 'null' to `Activation`."); | |||||
//} | |||||
//return new Shape((long[])(dims!)); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,48 @@ | |||||
using Newtonsoft.Json.Linq; | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Common | |||||
{ | |||||
public class CustomizedAxisJsonConverter : JsonConverter | |||||
{ | |||||
public override bool CanConvert(Type objectType) | |||||
{ | |||||
return objectType == typeof(Axis); | |||||
} | |||||
public override bool CanRead => true; | |||||
public override bool CanWrite => true; | |||||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
{ | |||||
if (value is null) | |||||
{ | |||||
var token = JToken.FromObject(new int[] { }); | |||||
token.WriteTo(writer); | |||||
} | |||||
else if (value is not Axis) | |||||
{ | |||||
throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); | |||||
} | |||||
else | |||||
{ | |||||
var token = JToken.FromObject((value as Axis)!.axis); | |||||
token.WriteTo(writer); | |||||
} | |||||
} | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
{ | |||||
var axis = serializer.Deserialize(reader, typeof(long[])); | |||||
if (axis is null) | |||||
{ | |||||
throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||||
} | |||||
return new Axis((int[])(axis!)); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,73 @@ | |||||
using Newtonsoft.Json; | |||||
using Newtonsoft.Json.Converters; | |||||
using Newtonsoft.Json.Linq; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Common | |||||
{ | |||||
public class CustomizedNodeConfigJsonConverter : JsonConverter | |||||
{ | |||||
public override bool CanConvert(Type objectType) | |||||
{ | |||||
return objectType == typeof(NodeConfig); | |||||
} | |||||
public override bool CanRead => true; | |||||
public override bool CanWrite => true; | |||||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
{ | |||||
if (value is null) | |||||
{ | |||||
var token = JToken.FromObject(null); | |||||
token.WriteTo(writer); | |||||
} | |||||
else if (value is not NodeConfig) | |||||
{ | |||||
throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); | |||||
} | |||||
else | |||||
{ | |||||
var config = value as NodeConfig; | |||||
var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); | |||||
token.WriteTo(writer); | |||||
} | |||||
} | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
{ | |||||
var values = serializer.Deserialize(reader, typeof(object[])) as object[]; | |||||
if (values is null) | |||||
{ | |||||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||||
} | |||||
if(values.Length != 3) | |||||
{ | |||||
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||||
} | |||||
if (values[0] is not string) | |||||
{ | |||||
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | |||||
} | |||||
if (values[1] is not int) | |||||
{ | |||||
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | |||||
} | |||||
if (values[2] is not int) | |||||
{ | |||||
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | |||||
} | |||||
return new NodeConfig() | |||||
{ | |||||
Name = values[0] as string, | |||||
NodeIndex = (int)values[1], | |||||
TensorIndex = (int)values[2] | |||||
}; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,67 @@ | |||||
using Newtonsoft.Json; | |||||
using Newtonsoft.Json.Converters; | |||||
using Newtonsoft.Json.Linq; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Common | |||||
{ | |||||
public class CustomizedShapeJsonConverter: JsonConverter | |||||
{ | |||||
public override bool CanConvert(Type objectType) | |||||
{ | |||||
return objectType == typeof(Shape); | |||||
} | |||||
public override bool CanRead => true; | |||||
public override bool CanWrite => true; | |||||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
{ | |||||
if(value is null) | |||||
{ | |||||
var token = JToken.FromObject(null); | |||||
token.WriteTo(writer); | |||||
} | |||||
else if(value is not Shape) | |||||
{ | |||||
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); | |||||
} | |||||
else | |||||
{ | |||||
var shape = (value as Shape)!; | |||||
long?[] dims = new long?[shape.ndim]; | |||||
for(int i = 0; i < dims.Length; i++) | |||||
{ | |||||
if (shape.dims[i] == -1) | |||||
{ | |||||
dims[i] = null; | |||||
} | |||||
else | |||||
{ | |||||
dims[i] = shape.dims[i]; | |||||
} | |||||
} | |||||
var token = JToken.FromObject(dims); | |||||
token.WriteTo(writer); | |||||
} | |||||
} | |||||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
{ | |||||
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
if(dims is null) | |||||
{ | |||||
throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||||
} | |||||
long[] convertedDims = new long[dims.Length]; | |||||
for(int i = 0; i < dims.Length; i++) | |||||
{ | |||||
convertedDims[i] = dims[i] ?? (-1); | |||||
} | |||||
return new Shape(convertedDims); | |||||
} | |||||
} | |||||
} |
@@ -16,23 +16,27 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Specifies the ndim, dtype and shape of every input to a layer. | /// Specifies the ndim, dtype and shape of every input to a layer. | ||||
/// </summary> | /// </summary> | ||||
public class InputSpec | |||||
public class InputSpec: IKerasConfigable | |||||
{ | { | ||||
public int? ndim; | public int? ndim; | ||||
public int? max_ndim; | |||||
public int? min_ndim; | public int? min_ndim; | ||||
Dictionary<int, int> axes; | Dictionary<int, int> axes; | ||||
Shape shape; | Shape shape; | ||||
TF_DataType dtype; | |||||
public int[] AllAxisDim; | public int[] AllAxisDim; | ||||
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int? ndim = null, | int? ndim = null, | ||||
int? min_ndim = null, | int? min_ndim = null, | ||||
int? max_ndim = null, | |||||
Dictionary<int, int> axes = null, | Dictionary<int, int> axes = null, | ||||
Shape shape = null) | Shape shape = null) | ||||
{ | { | ||||
@@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine | |||||
axes = new Dictionary<int, int>(); | axes = new Dictionary<int, int>(); | ||||
this.axes = axes; | this.axes = axes; | ||||
this.min_ndim = min_ndim; | this.min_ndim = min_ndim; | ||||
this.max_ndim = max_ndim; | |||||
this.shape = shape; | this.shape = shape; | ||||
this.dtype = dtype; | |||||
if (ndim == null && shape != null) | if (ndim == null && shape != null) | ||||
this.ndim = shape.ndim; | this.ndim = shape.ndim; | ||||
@@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine | |||||
AllAxisDim = axes.Select(x => x.Value).ToArray(); | AllAxisDim = axes.Select(x => x.Value).ToArray(); | ||||
} | } | ||||
public IKerasConfig get_config() | |||||
{ | |||||
return new Config() | |||||
{ | |||||
DType = dtype == TF_DataType.DtInvalid ? null : dtype, | |||||
Shape = shape, | |||||
Ndim = ndim, | |||||
MinNdim = min_ndim, | |||||
MaxNdim = max_ndim, | |||||
Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) | |||||
}; | |||||
} | |||||
public override string ToString() | public override string ToString() | ||||
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; | => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; | ||||
public class Config: IKerasConfig | |||||
{ | |||||
public TF_DataType? DType { get; set; } | |||||
public Shape Shape { get; set; } | |||||
public int? Ndim { get; set; } | |||||
public int? MinNdim { get;set; } | |||||
public int? MaxNdim { get;set; } | |||||
public IDictionary<string, int> Axes { get; set; } | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,10 +1,12 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Training; | |||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
{ | { | ||||
public interface ILayer | |||||
public interface ILayer: IWithTrackable, IKerasConfigable | |||||
{ | { | ||||
string Name { get; } | string Name { get; } | ||||
bool Trainable { get; } | bool Trainable { get; } | ||||
@@ -19,8 +21,8 @@ namespace Tensorflow.Keras | |||||
List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
TensorShapeConfig BuildInputShape { get; } | |||||
TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
int count_params(); | int count_params(); | ||||
LayerArgs get_config(); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
    /// <summary>
    /// Marker interface for objects that represent a Keras configuration
    /// (e.g. layer, model or node configs) suitable for serialization.
    /// </summary>
    public interface IKerasConfig
    {
    }
    /// <summary>
    /// Implemented by objects that can describe themselves as an <see cref="IKerasConfig"/>.
    /// </summary>
    public interface IKerasConfigable
    {
        /// <summary>
        /// Returns the serializable configuration of this object.
        /// </summary>
        IKerasConfig get_config();
    }
} |
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
@@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine; | |||||
namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
{ | { | ||||
public class LayerConfig | |||||
public class LayerConfig: IKerasConfig | |||||
{ | { | ||||
[JsonProperty("name")] | |||||
public string Name { get; set; } | public string Name { get; set; } | ||||
[JsonProperty("class_name")] | |||||
public string ClassName { get; set; } | public string ClassName { get; set; } | ||||
[JsonProperty("config")] | |||||
public LayerArgs Config { get; set; } | public LayerArgs Config { get; set; } | ||||
[JsonProperty("inbound_nodes")] | |||||
public List<NodeConfig> InboundNodes { get; set; } | public List<NodeConfig> InboundNodes { get; set; } | ||||
} | } | ||||
} | } |
@@ -1,15 +1,20 @@ | |||||
using System; | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
{ | { | ||||
public class ModelConfig | |||||
public class ModelConfig : IKerasConfig | |||||
{ | { | ||||
[JsonProperty("name")] | |||||
public string Name { get; set; } | public string Name { get; set; } | ||||
[JsonProperty("layers")] | |||||
public List<LayerConfig> Layers { get; set; } | public List<LayerConfig> Layers { get; set; } | ||||
[JsonProperty("input_layers")] | |||||
public List<NodeConfig> InputLayers { get; set; } | public List<NodeConfig> InputLayers { get; set; } | ||||
[JsonProperty("output_layers")] | |||||
public List<NodeConfig> OutputLayers { get; set; } | public List<NodeConfig> OutputLayers { get; set; } | ||||
public override string ToString() | public override string ToString() | ||||
@@ -1,10 +1,13 @@ | |||||
using System; | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Common; | |||||
namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
{ | { | ||||
public class NodeConfig | |||||
[JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] | |||||
public class NodeConfig : IKerasConfig | |||||
{ | { | ||||
public string Name { get; set; } | public string Name { get; set; } | ||||
public int NodeIndex { get; set; } | public int NodeIndex { get; set; } | ||||
@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | |||||
{ | |||||
    /// <summary>
    /// Container for the attributes (functions and checkpointable objects) that are
    /// attached to a Keras object when it is serialized to a SavedModel.
    /// </summary>
    public interface ISerializedAttributes
    {
        /// <summary>
        /// Returns the functions recorded for serialization, keyed by attribute name.
        /// </summary>
        IDictionary<string, Trackable> Functions { get; }

        /// <summary>
        /// Returns the checkpointable objects recorded for serialization, keyed by attribute name.
        /// </summary>
        IDictionary<string, Trackable> CheckpointableObjects { get; }

        /// <summary>
        /// Returns functions to attach to the root object during serialization.
        /// </summary>
        IDictionary<string, Trackable> FunctionsToSerialize { get; }

        /// <summary>
        /// Returns objects to attach to the root object during serialization.
        /// </summary>
        IDictionary<string, Trackable> ObjectsToSerialize{get; }

        /// <summary>
        /// Saves function dictionary, and validates dictionary values.
        /// </summary>
        /// <param name="function_dict"></param>
        IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict);

        /// <summary>
        /// Saves objects to a dictionary, and validates the values.
        /// </summary>
        /// <param name="object_dict"></param>
        IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict);
    }
} |
@@ -0,0 +1,21 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class TensorShapeConfig | |||||
{ | |||||
[JsonProperty("class_name")] | |||||
public string ClassName { get; set; } = "TensorShape"; | |||||
[JsonProperty("items")] | |||||
public long?[] Items { get; set; } | |||||
public static implicit operator Shape(TensorShapeConfig shape) | |||||
=> shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
public static implicit operator TensorShapeConfig(Shape shape) | |||||
=> new TensorShapeConfig() { Items = shape.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }; | |||||
} | |||||
} |
@@ -9,10 +9,52 @@ namespace Tensorflow.ModelSaving | |||||
/// </summary> | /// </summary> | ||||
public class SaveOptions | public class SaveOptions | ||||
{ | { | ||||
bool save_debug_info; | |||||
public bool save_debug_info = false; | |||||
public IList<string>? namespace_white_list { get; set; } = null; | |||||
public IDictionary<string, object>? function_aliases { get; set; } = null; | |||||
public string? experimental_io_device { get; set; } = null; | |||||
// TODO: experimental | |||||
public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; | |||||
public bool experimental_custom_gradients { get; set; } = true; | |||||
public SaveOptions(bool save_debug_info = false) | public SaveOptions(bool save_debug_info = false) | ||||
{ | { | ||||
this.save_debug_info = save_debug_info; | this.save_debug_info = save_debug_info; | ||||
} | } | ||||
} | } | ||||
public class VariablePolicy | |||||
{ | |||||
public string Policy { get; } | |||||
private VariablePolicy(string policy) | |||||
{ | |||||
Policy = policy; | |||||
} | |||||
public static VariablePolicy None = new(null); | |||||
public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); | |||||
public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); | |||||
public bool save_variable_devices() | |||||
{ | |||||
return this != VariablePolicy.None; | |||||
} | |||||
/// <summary> | |||||
/// Tries to convert `obj` to a VariablePolicy instance. | |||||
/// </summary> | |||||
/// <param name="obj"></param> | |||||
/// <returns></returns> | |||||
public static VariablePolicy from_obj(object obj) | |||||
{ | |||||
if (obj is null) return VariablePolicy.None; | |||||
if (obj is VariablePolicy) return (VariablePolicy)obj; | |||||
var key = obj.ToString().ToLower(); | |||||
return key switch | |||||
{ | |||||
null => VariablePolicy.None, | |||||
"save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||||
"expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||||
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | |||||
}; | |||||
} | |||||
} | |||||
} | } |
@@ -14,20 +14,29 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Newtonsoft.Json; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Common; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public record Axis(params int[] axis) | |||||
[JsonConverter(typeof(CustomizedAxisJsonConverter))] | |||||
public class Axis | |||||
{ | { | ||||
public int[] axis { get; set; } | |||||
public int size => axis == null ? -1 : axis.Length; | public int size => axis == null ? -1 : axis.Length; | ||||
public bool IsScalar { get; init; } | public bool IsScalar { get; init; } | ||||
public int this[int index] => axis[index]; | public int this[int index] => axis[index]; | ||||
public Axis(params int[] axis) | |||||
{ | |||||
this.axis = axis; | |||||
} | |||||
public static implicit operator int[]?(Axis axis) | public static implicit operator int[]?(Axis axis) | ||||
=> axis?.axis; | => axis?.axis; | ||||
@@ -14,14 +14,17 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Newtonsoft.Json; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Common; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
[JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||||
public class Shape | public class Shape | ||||
{ | { | ||||
public int ndim => _dims == null ? -1 : _dims.Length; | public int ndim => _dims == null ? -1 : _dims.Length; | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class Constant<T> : IInitializer | public class Constant<T> : IInitializer | ||||
@@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers | |||||
T value; | T value; | ||||
bool _verify_shape; | bool _verify_shape; | ||||
private readonly Dictionary<string, object> _config; | |||||
public string ClassName => "Constant"; | |||||
public IDictionary<string, object> Config => _config; | |||||
public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | ||||
{ | { | ||||
this.value = value; | this.value = value; | ||||
this.dtype = dtype; | this.dtype = dtype; | ||||
_verify_shape = verify_shape; | _verify_shape = verify_shape; | ||||
_config = new Dictionary<string, object>(); | |||||
_config["value"] = this.value; | |||||
} | } | ||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
@@ -14,10 +14,17 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class GlorotUniform : VarianceScaling | public class GlorotUniform : VarianceScaling | ||||
{ | { | ||||
private readonly Dictionary<string, object> _config; | |||||
public override string ClassName => "GlorotUniform"; | |||||
public override IDictionary<string, object> Config => _config; | |||||
public GlorotUniform(float scale = 1.0f, | public GlorotUniform(float scale = 1.0f, | ||||
string mode = "FAN_AVG", | string mode = "FAN_AVG", | ||||
bool uniform = true, | bool uniform = true, | ||||
@@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers | |||||
seed: seed, | seed: seed, | ||||
dtype: dtype) | dtype: dtype) | ||||
{ | { | ||||
_config = new Dictionary<string, object>(); | |||||
_config["seed"] = _seed; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -14,10 +14,17 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Newtonsoft.Json; | |||||
using System.Collections.Generic; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public interface IInitializer | public interface IInitializer | ||||
{ | { | ||||
[JsonProperty("class_name")] | |||||
string ClassName { get; } | |||||
[JsonProperty("config")] | |||||
IDictionary<string, object> Config { get; } | |||||
Tensor Apply(InitializerArgs args); | Tensor Apply(InitializerArgs args); | ||||
} | } | ||||
} | } |
@@ -14,12 +14,19 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class Ones : IInitializer | public class Ones : IInitializer | ||||
{ | { | ||||
private TF_DataType dtype; | private TF_DataType dtype; | ||||
private readonly Dictionary<string, object> _config; | |||||
public string ClassName => "Ones"; | |||||
public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) | public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
{ | { | ||||
this.dtype = dtype; | this.dtype = dtype; | ||||
@@ -1,4 +1,4 @@ | |||||
/***************************************************************************** | |||||
/***************************************************************************** | |||||
Copyright 2023 Haiping Chen. All Rights Reserved. | Copyright 2023 Haiping Chen. All Rights Reserved. | ||||
Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -19,6 +19,7 @@ using System.Linq; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Operations.Initializers; | namespace Tensorflow.Operations.Initializers; | ||||
using System.Collections.Generic; | |||||
public class Orthogonal : IInitializer | public class Orthogonal : IInitializer | ||||
{ | { | ||||
@@ -31,6 +32,10 @@ public class Orthogonal : IInitializer | |||||
_seed = seed; | _seed = seed; | ||||
} | } | ||||
private readonly Dictionary<string, object> _config; | |||||
public string ClassName => "Orthogonal"; | |||||
public IDictionary<string, object> Config => throw new NotImplementedException(); | |||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
{ | { | ||||
return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); | return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType); | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class RandomNormal : IInitializer | public class RandomNormal : IInitializer | ||||
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers | |||||
private int? seed; | private int? seed; | ||||
private TF_DataType dtype; | private TF_DataType dtype; | ||||
private readonly Dictionary<string, object> _config; | |||||
public string ClassName => "RandomNormal"; | |||||
public IDictionary<string, object> Config => _config; | |||||
public RandomNormal(float mean = 0.0f, | public RandomNormal(float mean = 0.0f, | ||||
float stddev = 0.05f, | float stddev = 0.05f, | ||||
int? seed = null, | int? seed = null, | ||||
@@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers | |||||
this.stddev = stddev; | this.stddev = stddev; | ||||
this.seed = seed; | this.seed = seed; | ||||
this.dtype = dtype; | this.dtype = dtype; | ||||
_config = new Dictionary<string, object>(); | |||||
_config["mean"] = this.mean; | |||||
_config["stddev"] = this.stddev; | |||||
_config["seed"] = this.seed; | |||||
} | } | ||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class RandomUniform : IInitializer | public class RandomUniform : IInitializer | ||||
@@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers | |||||
private float maxval; | private float maxval; | ||||
private TF_DataType dtype; | private TF_DataType dtype; | ||||
private readonly Dictionary<string, object> _config; | |||||
public string ClassName => "RandomUniform"; | |||||
public IDictionary<string, object> Config => _config; | |||||
public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) | public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) | ||||
{ | { | ||||
this.dtype = dtype; | this.dtype = dtype; | ||||
this.minval = minval; | this.minval = minval; | ||||
this.maxval = maxval; | this.maxval = maxval; | ||||
this.seed = seed; | this.seed = seed; | ||||
_config = new Dictionary<string, object>(); | |||||
_config["minval"] = this.minval; | |||||
_config["maxval"] = this.maxval; | |||||
_config["seed"] = this.seed; | |||||
} | } | ||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class TruncatedNormal : IInitializer | public class TruncatedNormal : IInitializer | ||||
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers | |||||
private int? seed; | private int? seed; | ||||
private TF_DataType dtype; | private TF_DataType dtype; | ||||
private readonly Dictionary<string, object> _config; | |||||
public string ClassName => "TruncatedNormal"; | |||||
public IDictionary<string, object> Config => _config; | |||||
public TruncatedNormal(float mean = 0.0f, | public TruncatedNormal(float mean = 0.0f, | ||||
float stddev = 1.0f, | float stddev = 1.0f, | ||||
int? seed = null, | int? seed = null, | ||||
@@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers | |||||
this.stddev = stddev; | this.stddev = stddev; | ||||
this.seed = seed; | this.seed = seed; | ||||
this.dtype = dtype; | this.dtype = dtype; | ||||
_config = new Dictionary<string, object>(); | |||||
_config["mean"] = this.mean; | |||||
_config["stddev"] = this.stddev; | |||||
_config["seed"] = this.seed; | |||||
} | } | ||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
@@ -15,7 +15,9 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Linq.Expressions; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
@@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers | |||||
protected int? _seed; | protected int? _seed; | ||||
protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
protected bool _uniform; | protected bool _uniform; | ||||
private readonly Dictionary<string, object> _config; | |||||
public virtual string ClassName => "VarianceScaling"; | |||||
public virtual IDictionary<string, object> Config => _config; | |||||
public VarianceScaling(float factor = 2.0f, | public VarianceScaling(float factor = 2.0f, | ||||
string mode = "FAN_IN", | string mode = "FAN_IN", | ||||
@@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers | |||||
_seed = seed; | _seed = seed; | ||||
_dtype = dtype; | _dtype = dtype; | ||||
_uniform = uniform; | _uniform = uniform; | ||||
_config = new(); | |||||
_config["scale"] = _scale; | |||||
_config["mode"] = _mode; | |||||
_config["distribution"] = _distribution; | |||||
_config["seed"] = _seed; | |||||
} | } | ||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
public class Zeros : IInitializer | public class Zeros : IInitializer | ||||
@@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers | |||||
Shape shape; | Shape shape; | ||||
TF_DataType dtype; | TF_DataType dtype; | ||||
public string ClassName => "Zeros"; | |||||
public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
{ | { | ||||
this.shape = shape; | this.shape = shape; | ||||
@@ -20,7 +20,9 @@ using Tensorflow.Keras; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Tensorflow.Train; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -75,6 +77,8 @@ namespace Tensorflow | |||||
public Shape BatchInputShape => throw new NotImplementedException(); | public Shape BatchInputShape => throw new NotImplementedException(); | ||||
public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); | |||||
public TF_DataType DType => throw new NotImplementedException(); | public TF_DataType DType => throw new NotImplementedException(); | ||||
protected bool built = false; | protected bool built = false; | ||||
public bool Built => built; | public bool Built => built; | ||||
@@ -143,7 +147,7 @@ namespace Tensorflow | |||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
public LayerArgs get_config() | |||||
public IKerasConfig get_config() | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
@@ -152,5 +156,7 @@ namespace Tensorflow | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
public Trackable GetTrackable() { throw new NotImplementedException(); } | |||||
} | } | ||||
} | } |
@@ -1,6 +1,9 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Xml.Linq; | |||||
using Tensorflow.Contexts; | |||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
@@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations | |||||
/// path in the input checkpoint_prefixes. This is useful when those paths are non | /// path in the input checkpoint_prefixes. This is useful when those paths are non | ||||
/// user-facing temporary locations. | /// user-facing temporary locations. | ||||
/// </remarks> | /// </remarks> | ||||
public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints") | |||||
{ | |||||
public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") | |||||
{ | |||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||||
checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); | |||||
result = null; | |||||
return null; | |||||
//try | |||||
//{ | |||||
// var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||||
// new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); | |||||
// result = null; | |||||
// return null; | |||||
//} | |||||
//catch (System.Exception) | |||||
//{ | |||||
// return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, | |||||
// allow_missing_files: allow_missing_files, name: name, ctx: ctx); | |||||
//} | |||||
} | |||||
var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
dict["checkpoint_prefixes"] = checkpoint_prefixes; | dict["checkpoint_prefixes"] = checkpoint_prefixes; | ||||
dict["destination_prefix"] = destination_prefix; | dict["destination_prefix"] = destination_prefix; | ||||
if (delete_old_dirs.HasValue) | |||||
dict["delete_old_dirs"] = delete_old_dirs.Value; | |||||
dict["delete_old_dirs"] = delete_old_dirs; | |||||
var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); | var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); | ||||
return op; | return op; | ||||
} | } | ||||
//public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) | |||||
//{ | |||||
// checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); | |||||
// destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); | |||||
// var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; | |||||
// var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; | |||||
// var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); | |||||
// result = null; | |||||
// return null; | |||||
//} | |||||
/// <summary> | /// <summary> | ||||
/// Transforms a spectrogram into a form that's useful for speech recognition. | /// Transforms a spectrogram into a form that's useful for speech recognition. | ||||
/// </summary> | /// </summary> | ||||
@@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") | public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") | ||||
{ | { | ||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); | |||||
return result[0]; | |||||
} | |||||
var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
dict["input"] = input; | dict["input"] = input; | ||||
dict["pattern"] = pattern; | dict["pattern"] = pattern; | ||||
@@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") | public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") | ||||
{ | { | ||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); | |||||
return result[0]; | |||||
} | |||||
var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
dict["basename"] = basename; | dict["basename"] = basename; | ||||
dict["shard"] = shard; | dict["shard"] = shard; | ||||
@@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") | public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") | ||||
{ | { | ||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); | |||||
return result[0]; | |||||
} | |||||
var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
dict["inputs"] = inputs; | dict["inputs"] = inputs; | ||||
if (separator != null) | if (separator != null) | ||||
@@ -14,7 +14,9 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Linq; | |||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -23,11 +25,41 @@ namespace Tensorflow | |||||
{ | { | ||||
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | ||||
{ | { | ||||
var ctx = tf.Context; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
var result = tf.Runner.TFE_FastPathExecute( | |||||
new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); | |||||
result = null; | |||||
return null; | |||||
} | |||||
catch (System.Exception) | |||||
{ | |||||
return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); | |||||
} | |||||
} | |||||
var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | ||||
return _op; | return _op; | ||||
} | } | ||||
public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) | |||||
{ | |||||
DataType[] attr_dtypes; | |||||
(attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx); | |||||
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||||
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||||
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||||
var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); | |||||
var attrs = new object[] { "dtypes", attr_dtypes }; | |||||
var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); | |||||
result = null; | |||||
return null; | |||||
} | |||||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | ||||
{ | { | ||||
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | ||||
@@ -17,6 +17,9 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Variables; | |||||
using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -38,6 +41,11 @@ namespace Tensorflow | |||||
{ | { | ||||
return var is ResourceVariable; | return var is ResourceVariable; | ||||
} | } | ||||
public static bool is_resource_variable(Trackable var) | |||||
{ | |||||
return var is BaseResourceVariable; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Creates a variable handle with information to do shape inference. | /// Creates a variable handle with information to do shape inference. | ||||
@@ -171,5 +179,57 @@ namespace Tensorflow | |||||
return HandleData.Parser.ParseFrom(handle.BufferToArray()); | return HandleData.Parser.ParseFrom(handle.BufferToArray()); | ||||
} | } | ||||
} | } | ||||
/// <summary> | |||||
/// Copies an existing variable to a new graph, with no initializer. | |||||
/// </summary> | |||||
/// <param name="variable"></param> | |||||
public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) | |||||
{ | |||||
var new_variable = new UninitializedVariable( | |||||
trainable: variable.Trainable, | |||||
shape: variable.shape, | |||||
dtype: variable.dtype, | |||||
name: variable.SharedName, | |||||
aggregation: variable.Aggregation, | |||||
extra_handle_data: null); | |||||
new_variable._maybe_initialize_trackable(); | |||||
return new_variable; | |||||
} | |||||
/// <summary> | |||||
/// Writes additional information of the variable into the SavedObject proto. | |||||
/// </summary> | |||||
/// <param name="resource_variable"></param> | |||||
/// <param name="proto"></param> | |||||
/// <param name="options"></param> | |||||
/// <param name="enforcing_naming"></param> | |||||
public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) | |||||
{ | |||||
// lack of API: `proto.Variable.SetInParent()`. | |||||
if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) | |||||
{ | |||||
throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + | |||||
$"unexpected suffix in the name (expected ':0') which won't be restored."); | |||||
} | |||||
if(proto.Variable is null) | |||||
{ | |||||
proto.Variable = new SavedVariable(); | |||||
} | |||||
proto.Variable.Name = meta_graph.op_name(resource_variable.Name); | |||||
proto.Variable.Trainable = resource_variable.Trainable; | |||||
proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); | |||||
// TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. | |||||
proto.Variable.Aggregation = resource_variable.Aggregation; | |||||
proto.Variable.Shape = resource_variable.shape.as_proto(); | |||||
if (options.experimental_variable_policy.save_variable_devices()) | |||||
{ | |||||
if (!string.IsNullOrEmpty(resource_variable.Device)) | |||||
{ | |||||
proto.Variable.Device = resource_variable.Device; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -156,7 +156,7 @@ namespace Tensorflow { | |||||
/// Nodes[0] is considered the root node. | /// Nodes[0] is considered the root node. | ||||
/// </summary> | /// </summary> | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||||
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||||
get { return nodes_; } | get { return nodes_; } | ||||
} | } | ||||
@@ -286,6 +286,7 @@ namespace Tensorflow { | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
public SavedObject(SavedObject other) : this() { | public SavedObject(SavedObject other) : this() { | ||||
children_ = other.children_.Clone(); | children_ = other.children_.Clone(); | ||||
dependencies_ = other.dependencies_.Clone(); | |||||
slotVariables_ = other.slotVariables_.Clone(); | slotVariables_ = other.slotVariables_.Clone(); | ||||
saveableObjects_ = other.saveableObjects_.Clone(); | saveableObjects_ = other.saveableObjects_.Clone(); | ||||
switch (other.KindCase) { | switch (other.KindCase) { | ||||
@@ -328,6 +329,7 @@ namespace Tensorflow { | |||||
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | ||||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | ||||
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | ||||
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||||
/// <summary> | /// <summary> | ||||
/// Objects which this object depends on: named edges in the dependency | /// Objects which this object depends on: named edges in the dependency | ||||
/// graph. | /// graph. | ||||
@@ -338,6 +340,11 @@ namespace Tensorflow { | |||||
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | ||||
get { return children_; } | get { return children_; } | ||||
} | } | ||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies { | |||||
get { return dependencies_; } | |||||
} | |||||
/// <summary>Field number for the "slot_variables" field.</summary> | /// <summary>Field number for the "slot_variables" field.</summary> | ||||
public const int SlotVariablesFieldNumber = 3; | public const int SlotVariablesFieldNumber = 3; | ||||
@@ -617,6 +624,7 @@ namespace Tensorflow { | |||||
return; | return; | ||||
} | } | ||||
children_.Add(other.children_); | children_.Add(other.children_); | ||||
dependencies_.Add(other.dependencies_); | |||||
slotVariables_.Add(other.slotVariables_); | slotVariables_.Add(other.slotVariables_); | ||||
saveableObjects_.Add(other.saveableObjects_); | saveableObjects_.Add(other.saveableObjects_); | ||||
switch (other.KindCase) { | switch (other.KindCase) { | ||||
@@ -198,6 +198,22 @@ namespace Tensorflow { | |||||
public TrackableObject() { | public TrackableObject() { | ||||
OnConstruction(); | OnConstruction(); | ||||
} | } | ||||
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    // Hand-added convenience constructor (not emitted by protoc): builds a
    // TrackableObject that adopts `slot` as its backing slot-variable collection
    // (the field is assigned, not copied).
    public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) {
      OnConstruction();
      slotVariables_ = slot;
    }
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    // Hand-added convenience constructor (not emitted by protoc): adopts both
    // the slot-variable collection and the children collection by reference.
    public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot,
        pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children
        )
    {
      OnConstruction();
      slotVariables_ = slot;
      children_ = children;
    }
partial void OnConstruction(); | partial void OnConstruction(); | ||||
@@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||||
<PackageReference Include="Protobuf.Text" Version="0.6.0" /> | <PackageReference Include="Protobuf.Text" Version="0.6.0" /> | ||||
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -202,6 +202,24 @@ namespace Tensorflow | |||||
_ => type.ToString() | _ => type.ToString() | ||||
}; | }; | ||||
public static string as_python_name(this TF_DataType type) | |||||
=> type switch | |||||
{ | |||||
TF_DataType.TF_STRING => "str", | |||||
TF_DataType.TF_UINT8 => "uint8", | |||||
TF_DataType.TF_INT8 => "int8", | |||||
TF_DataType.TF_UINT32 => "uint32", | |||||
TF_DataType.TF_INT32 => "int32", | |||||
TF_DataType.TF_UINT64 => "uint64", | |||||
TF_DataType.TF_INT64 => "int64", | |||||
TF_DataType.TF_FLOAT => "float32", | |||||
TF_DataType.TF_DOUBLE => "float64", | |||||
TF_DataType.TF_BOOL => "bool", | |||||
TF_DataType.TF_RESOURCE => "resource", | |||||
TF_DataType.TF_VARIANT => "variant", | |||||
_ => type.ToString() | |||||
}; | |||||
public static int get_datatype_size(this TF_DataType type) | public static int get_datatype_size(this TF_DataType type) | ||||
=> type.as_base_dtype() switch | => type.as_base_dtype() switch | ||||
{ | { | ||||
@@ -1,6 +1,71 @@ | |||||
namespace Tensorflow.Train | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Operations.Activation; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Train | |||||
{ | { | ||||
public abstract class AutoTrackable : Trackable | |||||
public class AutoTrackable : Trackable | |||||
{ | { | ||||
public void _delete_tracking(string name) | |||||
{ | |||||
_maybe_initialize_trackable(); | |||||
if (_unconditional_dependency_names.ContainsKey(name)) | |||||
{ | |||||
_unconditional_dependency_names.Remove(name); | |||||
for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) | |||||
{ | |||||
if (_unconditional_checkpoint_dependencies[i].Name == name) | |||||
{ | |||||
_unconditional_checkpoint_dependencies.RemoveAt(i); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | |||||
if(save_type != SaveType.SAVEDMODEL) | |||||
{ | |||||
return base._trackable_children(save_type, cache); | |||||
} | |||||
Dictionary<string, Trackable> functions = new(); | |||||
// TODO: process of logs. | |||||
var properties = this.GetType().GetProperties(); | |||||
foreach ( var property in properties ) | |||||
{ | |||||
string name = property.Name; | |||||
object value = property.GetValue(this, null); | |||||
if(value is Function || value is ConcreteFunction) | |||||
{ | |||||
functions[name] = (Trackable)value; | |||||
} | |||||
} | |||||
// TODO: process the type `core_types.GenericFunction`. | |||||
Dictionary<string, Trackable> children = new(); | |||||
foreach(var pair in CheckpointDependencies) | |||||
{ | |||||
var name = pair.Name; | |||||
var child = pair.Refer; | |||||
if(child is ConcreteFunction) // or Generic function | |||||
{ | |||||
continue; | |||||
} | |||||
if(functions.ContainsKey(name) && functions[name] != child) | |||||
{ | |||||
throw new ValueError($"Can't save object because it has multiple children with the same " + | |||||
$"name. Object: {this}, attribute name: {name}, child 1: " + | |||||
$"{child}, child 2: {functions[name]}"); | |||||
} | |||||
children[name] = child; | |||||
} | |||||
return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Training | |||||
{ | |||||
    /// <summary>
    /// Implemented by types that wrap or expose an underlying <see cref="Trackable"/>.
    /// </summary>
    public interface IWithTrackable
    {
        /// <summary>Returns the wrapped trackable object.</summary>
        Trackable GetTrackable();
    }
} |
@@ -0,0 +1,9 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Training | |||||
{ | |||||
} |
@@ -351,7 +351,7 @@ namespace Tensorflow | |||||
/// <param name="var"></param> | /// <param name="var"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
protected IVariableV1 get_slot(IVariableV1 var, string name) | |||||
internal IVariableV1 get_slot(IVariableV1 var, string name) | |||||
{ | { | ||||
var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | ||||
if (named_slots == null) | if (named_slots == null) | ||||
@@ -360,6 +360,11 @@ namespace Tensorflow | |||||
return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | ||||
} | } | ||||
internal IEnumerable<string> get_slot_names() | |||||
{ | |||||
return _slots.Keys; | |||||
} | |||||
private string _var_key(IVariableV1 var) | private string _var_key(IVariableV1 var) | ||||
{ | { | ||||
return $"{var.Op.graph.graph_key}.{var.Op.name}"; | return $"{var.Op.graph.graph_key}.{var.Op.name}"; | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class ResourceVariableSaveable : MySaveableObject | public class ResourceVariableSaveable : MySaveableObject | ||||
@@ -35,6 +37,32 @@ namespace Tensorflow | |||||
this.name = name; | this.name = name; | ||||
} | } | ||||
        // Builds a saveable for a resource variable: captures the variable's
        // current value (copied to CPU) as the single SaveSpec consumed by
        // checkpoint writers.
        public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name)
        {
            _var_device = var.Device;
            _var_shape = var.shape;
            // Reads the variable's value on its own device, then pins the result
            // to CPU through an identity op.
            // NOTE(review): `tf.device(...)` is invoked without a `using`/scope
            // here — presumably it switches the ambient device for the following
            // ops; confirm the intended scoping semantics.
            Tensor _read_variable_closure(BaseResourceVariable v)
            {
                tf.device(v.Device);
                // Eager mode: an uninitialized variable has no value to save.
                if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
                {
                    return null;
                }
                var x = v.read_value_no_copy();
                tf.device("/device:CPU:0");
                return array_ops.identity(x);
            }
            this.handle_op = var.Handle;
            // May be null for an uninitialized eager variable (see closure above).
            var tensor = _read_variable_closure(var);
            var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype);
            // Store the variable itself (not a tensor) as the underlying op.
            _op = var;
            specs = new SaveSpec[] { spec };
            this.name = name;
        }
public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | ||||
{ | { | ||||
var restored_tensor = restored_tensors[0]; | var restored_tensor = restored_tensors[0]; | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow | |||||
public string slice_spec => _slice_spec; | public string slice_spec => _slice_spec; | ||||
private string _name; | private string _name; | ||||
public string name => _name; | |||||
public string name { get => _name; set => _name = value; } | |||||
private TF_DataType _dtype; | private TF_DataType _dtype; | ||||
public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
@@ -14,11 +14,31 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.Checkpoint; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class MySaveableObject | public class MySaveableObject | ||||
{ | { | ||||
public Tensor op; | |||||
protected Maybe<Tensor, BaseResourceVariable> _op; | |||||
public Tensor op | |||||
{ | |||||
get | |||||
{ | |||||
if(_op.DataType == typeof(Tensor)) | |||||
{ | |||||
return _op.GetValueA(); | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("The _op is not a tensor."); | |||||
} | |||||
} | |||||
set | |||||
{ | |||||
_op = value; | |||||
} | |||||
} | |||||
public SaveSpec[] specs; | public SaveSpec[] specs; | ||||
public string name; | public string name; | ||||
public string device; | public string device; | ||||
@@ -35,7 +55,7 @@ namespace Tensorflow | |||||
public MySaveableObject(Tensor op, SaveSpec[] specs, string name) | public MySaveableObject(Tensor op, SaveSpec[] specs, string name) | ||||
{ | { | ||||
this.op = op; | |||||
this._op = op; | |||||
this.specs = specs; | this.specs = specs; | ||||
this.name = name; | this.name = name; | ||||
} | } | ||||
@@ -48,4 +68,18 @@ namespace Tensorflow | |||||
validate_shape: restored_shapes == null && op.shape.IsFullyDefined); | validate_shape: restored_shapes == null && op.shape.IsFullyDefined); | ||||
} | } | ||||
} | } | ||||
public class NoRestoreSaveable: MySaveableObject | |||||
{ | |||||
public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor, | |||||
new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) | |||||
{ | |||||
} | |||||
public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||||
{ | |||||
return control_flow_ops.no_op(); | |||||
} | |||||
} | |||||
} | } |
@@ -0,0 +1,11 @@ | |||||
using System.Collections.Generic; | |||||
namespace Tensorflow; | |||||
/// <summary>
/// Aggregated bookkeeping produced while gathering SavedModel assets.
/// </summary>
/// <param name="asset_defs">AssetFileDef entries to be serialized into the meta graph.</param>
/// <param name="asset_initializers_by_resource">NOTE(review): presumably maps asset resource tensors to their initializer ops — confirm against callers.</param>
/// <param name="asset_filename_map">NOTE(review): keyed by <see cref="AssetInfo"/> here, which looks suspicious (the Python counterpart keys by filename) — verify.</param>
/// <param name="asset_index">NOTE(review): presumably maps asset objects to their index in <paramref name="asset_defs"/> — confirm.</param>
public record class AssetInfo
(
    List<AssetFileDef> asset_defs,
    Dictionary<object, object> asset_initializers_by_resource,
    Dictionary<AssetInfo, string> asset_filename_map,
    Dictionary<object, object> asset_index
);
@@ -0,0 +1,133 @@ | |||||
using System; | |||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Train; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
namespace Tensorflow; | |||||
/// <summary>
/// An <see cref="ObjectGraphView"/> specialized for SavedModel export. It caches
/// each trackable's children, allows a signature map to be attached to the root,
/// records functions that were never traced, and substitutes save-time wrapped
/// concrete functions for their originals.
/// </summary>
public class AugmentedGraphView : ObjectGraphView
{
    // Per-object child cache; guarantees repeated listings return identical objects.
    private Dictionary<Trackable, IDictionary<string, Trackable>> _children_cache;
    private Dictionary<string, IDictionary<Trackable, ISerializedAttributes>> _serialization_cache;
    // Names of `Function` objects discovered to have no traced children.
    private List<string> _untraced_functions;
    // Maps an original concrete function to its save-time wrapper.
    private Dictionary<ConcreteFunction, ConcreteFunction> _wrapped_functions;

    public AugmentedGraphView(Trackable root) : base(root)
    {
        _children_cache = new Dictionary<Trackable, IDictionary<string, Trackable>>();
        _serialization_cache = new Dictionary<string, IDictionary<Trackable, ISerializedAttributes>>();
        _untraced_functions = new List<string>();
        _wrapped_functions = new Dictionary<ConcreteFunction, ConcreteFunction>();
    }

    /// <summary>
    /// Attaches the signature map as a child of the root and records the function
    /// wrappers produced while building the signatures.
    /// </summary>
    public void set_signature(SignatureMap signature_map, IDictionary<ConcreteFunction, ConcreteFunction> wrapped_functions)
    {
        // Populate the root's child cache before mutating it.
        list_children(Root);
        var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME;
        if (!_children_cache.ContainsKey(Root))
        {
            _children_cache[Root] = new Dictionary<string, Trackable>();
        }
        _children_cache[Root][name] = signature_map;
        // Merge with override semantics (matches Python's `dict.update`). The
        // previous `Concat(...).ToDictionary(...)` threw ArgumentException
        // whenever a function had already been registered.
        foreach (var pair in wrapped_functions)
        {
            _wrapped_functions[pair.Key] = pair.Value;
        }
    }

    /// <summary>
    /// Lists (and caches) the SavedModel children of <paramref name="obj"/>,
    /// substituting wrapped concrete functions where one is registered.
    /// </summary>
    /// <exception cref="ValueError">If a serialization cache is passed in (callers must use the view's own cache).</exception>
    public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
    {
        if (serialization_cache is not null)
        {
            throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`.");
        }
        if (!_children_cache.ContainsKey(obj))
        {
            Dictionary<string, Trackable> children = new Dictionary<string, Trackable>();
            _children_cache[obj] = children;
            foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache))
            {
                var name = pair.Name;
                var child = pair.Refer;
                if (child is ConcreteFunction)
                {
                    child = maybe_uncache_variable_captures((ConcreteFunction)child);
                }
                children[name] = child;
            }
            // A `Function` with no children has never been traced; remember its
            // name so a warning can be surfaced later.
            if (obj is Function && children.Count == 0)
            {
                _untraced_functions.Add(((Function)obj).Name);
            }
        }
        List<TrackableReference> res = new();
        foreach (var pair in _children_cache[obj])
        {
            res.Add(new TrackableReference(pair.Key, pair.Value));
        }
        return res;
    }

    /// <summary>
    /// Returns the save-time wrapper of <paramref name="concrete_function"/> if
    /// one was registered; otherwise the function itself.
    /// </summary>
    private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function)
    {
        if (_wrapped_functions.ContainsKey(concrete_function))
        {
            return _wrapped_functions[concrete_function];
        }
        // skip the process here because of lack of feature.
        // In the future, we may add an attribute which could specify if the variable is supposed to be cached.
        //foreach(var capture in concrete_function.CapturedInputs)
        //{

        //}
        return concrete_function;
    }

    /// <summary>
    /// Traverses the object graph. The first traversal populates the child cache
    /// as a side effect; cached children are then normalized (merging of assets
    /// and constants, once supported) and the graph is traversed again so the
    /// result reflects the normalized cache.
    /// </summary>
    public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
    {
        Trackable get_merged_trackable(Trackable x)
        {
            // TODO: complete it with new definitions `Asset` and `TrackableConstant`.
            return x;
        }
        // Result intentionally discarded; this call exists to fill `_children_cache`.
        var trackable_objects = base.breadth_first_traversal();
        foreach (var obj in _children_cache.Keys)
        {
            // skip the deletion of cache (maybe do it later).
            foreach (var pair in _children_cache[obj])
            {
                _children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
            }
        }
        return base.breadth_first_traversal();
    }

    /// <summary>
    /// Returns the (name, dependency) pairs of <paramref name="obj"/> computed
    /// from its cached children (empty when the object has not been listed yet).
    /// </summary>
    public List<(string, Trackable)> list_dependencies(Trackable obj)
    {
        IDictionary<string, Trackable> children = _children_cache.ContainsKey(obj)
            ? _children_cache[obj]
            : new Dictionary<string, Trackable>();
        List<(string, Trackable)> res = new();
        foreach (var pair in obj.deserialization_dependencies(children))
        {
            res.Add((pair.Key, pair.Value));
        }
        return res;
    }

    /// <summary>Returns the cached child of <paramref name="obj"/> named <paramref name="name"/>; throws if it is not cached.</summary>
    public Trackable get_child(Trackable obj, string name)
    {
        return _children_cache[obj][name];
    }
}
@@ -0,0 +1,33 @@ | |||||
namespace Tensorflow; | |||||
/// <summary>
/// File-layout and signature-key constants of the SavedModel format
/// (mirrors TensorFlow's `saved_model/constants.py`).
/// </summary>
public static class Constants
{
    // Subdirectory holding external asset files referenced by the model.
    public static readonly string ASSETS_DIRECTORY = "assets";
    // Collection key under which asset defs are stored in the meta graph.
    public static readonly string ASSETS_KEY = "saved_model_assets";
    // Subdirectory and filename for optional debug information.
    public static readonly string DEBUG_DIRECTORY = "debug";
    public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb";
    // Subdirectory for user-supplied extra assets (not loaded by TensorFlow).
    public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra";
    public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb";
    // Signature key of the init op recorded on save.
    public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
    // Legacy collection keys for init/main ops (TF1 compatibility).
    public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op";
    public static readonly string MAIN_OP_KEY = "saved_model_main_op";
    // The serialized SavedModel proto, binary and text forms.
    public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb";
    public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt";
    public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1;
    public static readonly string TRAIN_OP_KEY = "saved_model_train_op";
    public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op";
    // Subdirectory and file prefix for checkpointed variables.
    public static readonly string VARIABLES_DIRECTORY = "variables";
    public static readonly string VARIABLES_FILENAME = "variables";
}
@@ -0,0 +1,17 @@ | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow; | |||||
/// <summary>
/// Registry for revived (deserialized) object types.
/// </summary>
public class RevivedTypes
{
    /// <summary>
    /// Create a SavedUserObject from a trackable object.
    /// </summary>
    /// <param name="obj">The trackable to serialize.</param>
    /// <returns>Currently always <c>null</c>; type registration is not implemented yet.</returns>
    public static SavedUserObject? serialize(Trackable obj)
        // TODO: complete the implementation.
        => null;
}
@@ -0,0 +1,9 @@ | |||||
using System; | |||||
namespace Tensorflow; | |||||
/// <summary>
/// Distinguishes the two persistence formats a trackable object graph can be saved to.
/// </summary>
public enum SaveType
{
    // Full SavedModel export (graph, signatures, and checkpoint).
    SAVEDMODEL,
    // Weights-only checkpoint.
    CHECKPOINT
}
@@ -0,0 +1,299 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Contexts; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Training.Saving.SavedModel; | |||||
namespace Tensorflow; | |||||
/// <summary>
/// A frozen view of a `Trackable` object graph used while exporting a SavedModel,
/// corresponding to `_SaveableView` in tensorflow/python/saved_model/save.py.
/// </summary>
public class SaveableView
{
    private AugmentedGraphView _augmented_graph_view;
    private SaveOptions _options;
    private List<Trackable> _trackable_objects;
    private List<Trackable> _nodes;
    private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths;
    private Dictionary<Trackable, int> _node_ids;
    private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        _slot_variables;
    private Dictionary<Trackable, string> _object_names;
    private List<object> _gradient_functions; // to be completed
    private List<RegisteredGradient> _gradient_defs; // to be completed
    private List<ConcreteFunction> _concrete_functions;
    // Maps captured tensors to the id of the node that produced them.
    private Dictionary<Tensor, int> _captured_tensor_node_ids;
    private Dictionary<Trackable, IDictionary<string, ConcreteFunction>> _saveable_objects_map;
    private Dictionary<Trackable, string> _obj_to_registered_saver;

    public AugmentedGraphView AugmentedGraphView
    {
        get => _augmented_graph_view;
    }
    // By construction the root object is always node 0.
    public Trackable Root
    {
        get => _nodes[0];
    }
    public List<Trackable> Nodes
    {
        get => _nodes;
    }
    public Dictionary<Trackable, int> NodeIds
    {
        get => _node_ids;
    }
    public List<RegisteredGradient> GradientDefs
    {
        get => _gradient_defs;
    }
    public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths
    {
        get => _node_paths;
    }

    public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options)
    {
        _augmented_graph_view = augmented_graph_view;
        _options = options;
        (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) =
            CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view);

        // TODO: deal with untraced functions.

        initialize_save_and_restore_functions();
        initialize_nodes_and_concrete_functions();

        _captured_tensor_node_ids = new();
    }

    /// <summary>
    /// Generates the checkpoint factory map; registered savers and
    /// saveable-object maps are not populated yet.
    /// </summary>
    private void initialize_save_and_restore_functions()
    {
        // TODO: deal with the return value of `get_checkpoint_factories_and_keys`.
        var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names);

        // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver.
        _obj_to_registered_saver = new();
        _saveable_objects_map = new();
    }

    /// <summary>
    /// Copies the node list and collects all `ConcreteFunction` nodes.
    /// </summary>
    private void initialize_nodes_and_concrete_functions()
    {
        _nodes = _trackable_objects.ConvertAll(x => x); // shallow copy of the node list
        _gradient_functions = new();
        _gradient_defs = new();
        // BUGFIX: `_concrete_functions` was never initialized, so the `Add` below
        // threw a NullReferenceException whenever a node was a ConcreteFunction.
        _concrete_functions = new();

        // TODO: deal with the condition that obj in `_saveable_objects_map`.
        // foreach (var obj in _nodes)
        // {
        //
        // }

        foreach (var obj in _nodes)
        {
            if (obj is ConcreteFunction concrete)
            {
                _concrete_functions.Add(concrete);
            }
        }
    }

    public List<ConcreteFunction> get_concrete_resource_initializers()
    {
        // TODO: complete the implementation.
        return new List<ConcreteFunction>();
    }

    /// <summary>
    /// Makes new resource and asset ops reachable from the exported graph and
    /// records which node produced each captured tensor.
    /// </summary>
    /// <returns>Object map, tensor map and asset info for the export graph.</returns>
    public (Dictionary<Trackable, Trackable>, Dictionary<Tensor, Tensor>, AssetInfo) map_resources()
    {
        // Only valid while building the (non-eager) export graph.
        Debug.Assert(!tf.Context.executing_eagerly());
        Dictionary<Trackable, Trackable> object_map = new();
        Dictionary<Tensor, Tensor> tensor_map = new();
        AssetInfo assetInfo = new(new List<AssetFileDef>(), new Dictionary<object, object>(),
            new Dictionary<AssetInfo, string>(), new Dictionary<object, object>());

        foreach (var node_id in dependency_sorted_node_ids())
        {
            var obj = _nodes[node_id];
            var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options);
            // TODO: deal with Asset (if obj is Asset)
            foreach (var tensor in tensors)
            {
                _captured_tensor_node_ids[tensor] = node_id;
            }
        }

        return (object_map, tensor_map, assetInfo);
    }

    /// <summary>
    /// Returns topologically sorted nodes, sorted by dependencies.
    /// </summary>
    /// <exception cref="ValueError">
    /// Thrown when a dependency is untracked or the graph contains a cycle.
    /// </exception>
    public List<int> dependency_sorted_node_ids()
    {
        Dictionary<int, IEnumerable<int>> dependency_map = new();
        foreach (var node in _nodes)
        {
            var node_id = _node_ids[node];
            List<int> deps = new List<int>();
            dependency_map.Add(node_id, deps);

            // TODO: deal with captured tensor.

            foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node))
            {
                if (!_node_ids.ContainsKey(dep))
                {
                    var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
                    throw new ValueError(
                        $"Found an untracked dependency. Object {node_path} depends on {dep}, " +
                        $"but this dependency isn't listed as a child. Please track this child by " +
                        $"overriding `_trackable_children` or use `._track_trackable`.");
                }
                deps.Add(_node_ids[dep]);
            }
        }

        try
        {
            return TrackableUtils.order_by_dependency(dependency_map);
        }
        catch (TrackableUtils.CyclicDependencyError err)
        {
            List<string> pretty_printed_nodes = new();
            List<string> pretty_printed_dependencies = new();

            foreach (var pair in err.LeftOverDependencyMap)
            {
                var x = pair.Key;
                var deps = pair.Value;
                var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]);
                pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})");
                // BUGFIX: the lambda parameter was named `x`, shadowing the local `x`
                // above (compile error CS0136); renamed to `d`.
                pretty_printed_dependencies.Add(
                    $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(d => d.ToString()))}]");
            }

            throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " +
                $"Saving cannot continue until this cycle is resolved." +
                $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" +
                $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}");
        }
    }

    /// <summary>
    /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph
    /// </summary>
    /// <param name="asset_file_def_index"></param>
    /// <returns>The serialized `SavedObjectGraph` proto.</returns>
    public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index)
    {
        SavedObjectGraph proto = new();
        fill_object_graph_proto(proto);

        // TODO: complete the process of concrete functions.

        int cnt = Math.Min(_nodes.Count, proto.Nodes.Count);
        for (int i = 0; i < cnt; i++)
        {
            var obj = _nodes[i];
            var obj_proto = proto.Nodes[i];
            write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x));
        }
        return proto;
    }

    /// <summary>
    /// Fills in the type-specific payload of a single `SavedObject` proto.
    /// </summary>
    private static void write_object_proto(Trackable obj, SavedObject proto,
        IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn)
    {
        // skip the process of type Asset
        if (resource_variable_ops.is_resource_variable(obj))
        {
            var options = SaveContext.get_save_options();
            (obj as BaseResourceVariable).write_object_proto(proto, options);
        }
        else if (obj is Function)
        {
            // TODO: complete it.
            throw new NotImplementedException();
        }
        else if (obj is ConcreteFunction)
        {
            // TODO: complete it.
            throw new NotImplementedException();
        }
        // skip the process of type `_CapturedTensor` and `CapturableResource`.
        else
        {
            var registered_type_proto = RevivedTypes.serialize(obj);
            if (registered_type_proto is null)
            {
                // Fall back to a generic user-object proto keyed by the identifier.
                registered_type_proto = new SavedUserObject()
                {
                    Identifier = obj.ObjectIdentifier,
                    Version = new VersionDef()
                    {
                        Producer = 1,
                        MinConsumer = 1,
                        BadConsumers = { }
                    }
                };
            }
            proto.UserObject = new SavedUserObject(registered_type_proto);
        }

        // TODO: try get the registered_name from `registration`.
    }

    /// <summary>
    /// Writes the node topology (children, dependencies, slot variables)
    /// of every node into the object-graph proto.
    /// </summary>
    public void fill_object_graph_proto(SavedObjectGraph proto)
    {
        for (int node_id = 0; node_id < _nodes.Count; node_id++)
        {
            var node = _nodes[node_id];
            Debug.Assert(_node_ids[node] == node_id);
            SavedObject object_proto = new();
            if (_slot_variables.TryGetValue(node, out var value))
            {
                object_proto.SlotVariables.AddRange(value);
            }

            // skip the check of type `_CapturedTensor`

            foreach (var child in _augmented_graph_view.list_children(node))
            {
                var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
                child_proto.NodeId = _node_ids[child.Refer];
                child_proto.LocalName = child.Name;
                object_proto.Children.Add(child_proto);
            }

            foreach (var pair in _augmented_graph_view.list_dependencies(node))
            {
                var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
                child_proto.NodeId = _node_ids[pair.Item2];
                child_proto.LocalName = pair.Item1;
                object_proto.Dependencies.Add(child_proto);
            }

            if (_saveable_objects_map.ContainsKey(node))
            {
                // TODO: complete it.
                throw new NotImplementedException();
            }
            else if(_obj_to_registered_saver.ContainsKey(node))
            {
                // TODO: complete it.
                // We now skip it for the lack of `SavedObject.registered_saver` API.
                throw new NotImplementedException();
            }

            proto.Nodes.Add(object_proto);
        }
    }
}
@@ -0,0 +1,10 @@ | |||||
namespace Tensorflow; | |||||
/// <summary>
/// Standard tags attached to `MetaGraphDef`s in a SavedModel,
/// mirroring tensorflow/python/saved_model/tag_constants.py.
/// </summary>
public static class TagConstants
{
    // Graph exported for serving (inference).
    public static readonly string SERVING = "serve";
    // Graph exported for training.
    public static readonly string TRAINING = "train";
    // Graph exported for evaluation.
    public static readonly string EVAL = "eval";
    // Graph built to run on GPU devices.
    public static readonly string GPU = "gpu";
    // Graph built to run on TPU devices.
    public static readonly string TPU = "tpu";
}
@@ -0,0 +1,22 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow; | |||||
public class BuilderUtils
{
    /// <summary>
    /// Copies the asset files referenced by <paramref name="asset_filename_map"/>
    /// into the assets sub-directory of <paramref name="destination_dir"/>,
    /// creating that directory if needed.
    /// </summary>
    /// <param name="asset_filename_map">Maps asset info to source filenames.</param>
    /// <param name="destination_dir">The SavedModel export directory.</param>
    /// <param name="saved_files">Optional set used to de-duplicate copies.</param>
    public static void copy_assets_to_destination_dir(IDictionary<AssetInfo, string> asset_filename_map,
        string destination_dir, HashSet<string>? saved_files = null)
    {
        saved_files ??= new HashSet<string>();
        // Ensure the assets directory exists even when there is nothing to copy.
        SavedModelUtils.get_or_create_assets_dir(destination_dir);
        // TODO: complete the implementation of this function.
        if (asset_filename_map is not null && asset_filename_map.Count > 0)
        {
            throw new NotImplementedException();
        }
    }
}
@@ -0,0 +1,269 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Google.Protobuf; | |||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Exceptions; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Training.Saving.SavedModel; | |||||
namespace Tensorflow; | |||||
public static partial class SavedModelUtils
{
    // Dtypes whose tensor content must be byte-swapped on big-endian hosts.
    // Materialized into a HashSet so membership checks are O(1) and the LINQ
    // query is not re-evaluated on every call (the original kept a deferred
    // IEnumerable that re-ran `Select` on each `Contains`).
    private static readonly HashSet<int> byte_swappable = new List<TF_DataType>()
    {
        dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16,
        dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32,
        dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16,
        TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32
    }.Select(x => (int)x).ToHashSet();

    /// <summary>
    /// Saves a trackable object graph as a SavedModel under `export_dir` and
    /// returns the saved nodes together with their paths from the root.
    /// </summary>
    /// <param name="obj">Root trackable to export.</param>
    /// <param name="export_dir">Destination directory of the SavedModel.</param>
    /// <param name="signatures">Optional signature function; discovered automatically when null.</param>
    /// <param name="options">Save options; defaults are used when null.</param>
    /// <param name="experimental_skip_checkpoint">Skips writing the variables checkpoint when true.</param>
    public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj,
        string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
    {
        if (options is null)
        {
            options = new SaveOptions();
        }

        var saved_model = new Tensorflow.SavedModel();
        var meta_graph_def = new MetaGraphDef();
        saved_model.MetaGraphs.Add(meta_graph_def);

        var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) =
            _build_meta_graph(obj, signatures, options, meta_graph_def);
        saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION;

        if (!experimental_skip_checkpoint)
        {
            SavedModelUtils.get_or_create_variables_dir(export_dir);
            CheckpointOptions ckpt_options = new(options.experimental_io_device);
            object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options);
        }
        BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir);

        if (tf.Context.executing_eagerly())
        {
            // tensorflow python has a check of `context.async_wait()` here.
        }

        // TODO: deal with `pywrap_saved_model.Save(export_dir)`.

        // This is a state depending on some py-c APIs. Here we temporarily set it as `true`.
        var fingerprinting_enabled = true;
        if (fingerprinting_enabled)
        {
            var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir),
                tf.compat.as_str(Constants.FINGERPRINT_FILENAME));
            // TODO: add c api and complete the fingerprint def.
            var fingerprint_proto = "";
            File.WriteAllText(fingerprint_path, fingerprint_proto);
        }

        var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB));
        File.WriteAllBytes(path, saved_model.ToByteArray());
        //File.WriteAllText(path, saved_model.ToString());

        if (options.save_debug_info)
        {
            throw new NotImplementedException();
        }

        // The exported graph is only needed while serializing; free it now.
        ops.dismantle_graph(exported_graph);

        return (saved_nodes, node_paths);
    }

    /// <summary>
    /// Builds the `MetaGraphDef` for the export inside a save context.
    /// </summary>
    /// <exception cref="AssertionError">Thrown when called inside a traced tf.function.</exception>
    private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>,
        Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
        ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
    {
        using (SaveContext.save_context(options))
        {
            if (ops.inside_function())
            {
                throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " +
                    "Move the call to the outer eagerly-executed context.");
            }

            if (meta_graph_def is null)
            {
                meta_graph_def = new MetaGraphDef();
            }

            AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
            if (signatures is null)
            {
                signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view);
            }

            // TODO: process of signatures and wrapped_functions

            SaveableView saveable_view = new SaveableView(augmented_graph_view, options);
            TrackableSaver object_saver = new TrackableSaver(augmented_graph_view);
            var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures,
                options.namespace_white_list, options.experimental_custom_gradients);
            if (options.function_aliases is not null)
            {
                var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases;
                foreach (var pair in options.function_aliases)
                {
                    var alias = pair.Key;
                    var func = pair.Value;
                    // TODO: complete it.
                    throw new NotImplementedException();
                }
            }

            var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index);
            meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto);

            return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths);
        }
    }

    /// <summary>
    /// Populates the graph def, saver def, collections and meta-info of the
    /// `MetaGraphDef` from the saveable view.
    /// </summary>
    private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view,
        ConcreteFunction signatures, IEnumerable<string> namespace_whitelist,
        bool save_custom_gradients)
    {
        var resource_initializers = saveable_view.get_concrete_resource_initializers();
        var exported_graph = new Graph();

        Dictionary<Trackable, Trackable> object_map;
        Dictionary<Tensor, Tensor> tensor_map;
        AssetInfo asset_info;
        var g = exported_graph.as_default();
        (object_map, tensor_map, asset_info) = saveable_view.map_resources();
        // TODO: deal with signatures.
        if (save_custom_gradients)
        {
            // TODO: trace gradient functions.
        }

        foreach (var resource_initializer_function in resource_initializers)
        {
            // List<Trackable> asset_dependencies = new();
            // TODO: deal with initializers
        }

        // using(ops.control_dependencies(...))
        var init_op = control_flow_ops.no_op();
        // BUGFIX: the original used the LINQ `Append` extension on the repeated
        // field, which returns a new sequence without mutating the proto, so the
        // init op name was never recorded; `Add` mutates in place. The else-branch
        // also never added the name.
        if (!meta_graph_def.CollectionDef.TryGetValue(Tensorflow.Constants.MAIN_OP_KEY, out var main_op_collection))
        {
            main_op_collection = new CollectionDef();
            meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = main_op_collection;
        }
        if (main_op_collection.NodeList is null)
        {
            main_op_collection.NodeList = new CollectionDef.Types.NodeList();
        }
        main_op_collection.NodeList.Value.Add(init_op.name);

        // Lack `CopyFrom` API
        // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]
        g.Exit();

        foreach (var obj in object_map.Values)
        {
            obj._maybe_initialize_trackable();
        }

        // TODO: add the implementation of `call_with_mapped_functions`.
        var (named_saveable_objects, registered_savers) =
            SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false);
        var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false);

        var eg = exported_graph.as_default();
        var saver_def = saver.to_proto();
        meta_graph_def.SaverDef = saver_def;
        eg.Exit();

        saveable_view.dependency_sorted_node_ids();

        var graph_def = exported_graph.as_graph_def(true);
        graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs);
        verify_ops(graph_def, namespace_whitelist);

        meta_graph_def.GraphDef = new GraphDef(graph_def);
        meta_graph_def.MetaInfoDef = new();
        meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING);
        meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION;
        // TODO: add git version.
        meta_graph_def.MetaInfoDef.TensorflowGitVersion = "";
        meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
        meta_graph_def.MetaInfoDef.StrippedOpList = new();
        meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef));
        meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs);

        // TODO: deal with signatures here.

        meta_graph.strip_graph_default_valued_attrs(meta_graph_def);
        // Protos store tensor content little-endian; swap on big-endian hosts.
        if (!BitConverter.IsLittleEndian)
        {
            swap_function_tensor_content(meta_graph_def);
        }

        return (asset_info, exported_graph);
    }

    private static void verify_ops(GraphDef graph_def, IEnumerable<string>? namespace_whitelist)
    {
        return;
        // if (namespace_whitelist is null || !namespace_whitelist.Any())
        // {
        //     return;
        // }

        // skip the check for the lack of `meta_graph.ops_used_by_graph_def`.
    }

    /// <summary>
    /// Byte-swaps the tensor content of every `Const` node inside the function
    /// library of the meta graph (big-endian host support).
    /// </summary>
    public static void swap_function_tensor_content(MetaGraphDef meta_graph_def)
    {
        var functions = meta_graph_def.GraphDef.Library.Function;
        foreach (var function in functions)
        {
            var node_def = function.NodeDef;
            foreach (var node in node_def)
            {
                if (node.Op == "Const")
                {
                    var tensor = node.Attr["value"].Tensor;
                    byte_swap_tensor_content(tensor);
                }
            }
        }
    }

    /// <summary>
    /// Reverses the byte order of each element in the tensor's raw content,
    /// for dtypes that require swapping. No-op for other dtypes or empty content.
    /// </summary>
    public static void byte_swap_tensor_content(TensorProto tensor)
    {
        if (!byte_swappable.Contains((int)tensor.Dtype))
        {
            return;
        }
        var tensor_bytes = tensor.TensorContent;
        if (tensor_bytes is null || tensor_bytes.IsEmpty)
        {
            return;
        }
        long tensor_size = 1;
        foreach (var dim in tensor.TensorShape.Dim)
        {
            tensor_size *= dim.Size;
        }
        // Guard against zero-element shapes: the original divided by
        // `tensor_size` unconditionally, risking a DivideByZeroException
        // and an infinite loop with a zero chunk size.
        if (tensor_size <= 0)
        {
            return;
        }
        var chunksize = (int)(tensor_bytes.Length / tensor_size);
        if (chunksize <= 0)
        {
            return;
        }
        List<byte> reversed_bytes = new();
        for (int i = 0; i < tensor_bytes.Length; i += chunksize)
        {
            reversed_bytes.AddRange(tensor_bytes.Skip(i).Take(chunksize).Reverse());
        }
        tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray());
    }
}
@@ -0,0 +1,53 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.ModelSaving; | |||||
namespace Tensorflow.Training.Saving.SavedModel | |||||
{ | |||||
/// <summary> | |||||
/// A context for building a graph of SavedModel. | |||||
/// </summary> | |||||
public static class SaveContext | |||||
{ | |||||
// TODO: make it thead safe. | |||||
private static bool _in_save_context = false; | |||||
private static SaveOptions _save_options = null; | |||||
public static bool in_save_context() => _in_save_context; | |||||
public static SaveOptions get_save_options() | |||||
{ | |||||
if (!in_save_context()) | |||||
{ | |||||
throw new ValueError("Not in a SaveContext."); | |||||
} | |||||
return _save_options; | |||||
} | |||||
public static SaveContextHandler save_context(SaveOptions options) | |||||
{ | |||||
return new SaveContextHandler(options); | |||||
} | |||||
public class SaveContextHandler: IDisposable | |||||
{ | |||||
private bool _old_in_save_context; | |||||
private SaveOptions _old_save_options; | |||||
public SaveContextHandler(SaveOptions options) | |||||
{ | |||||
if (SaveContext.in_save_context()) | |||||
{ | |||||
throw new ValueError("Already in a SaveContext."); | |||||
} | |||||
_old_in_save_context = SaveContext._in_save_context; | |||||
SaveContext._in_save_context = true; | |||||
_old_save_options = SaveContext._save_options; | |||||
SaveContext._save_options = options; | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
SaveContext._in_save_context = _old_in_save_context; | |||||
SaveContext._save_options = _old_save_options; | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,107 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow; | |||||
public static class SignatureSerializationUtils
{
    internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature";
    internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures";
    internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5;

    /// <summary>
    /// Builds a `SignatureMap` trackable from a dictionary of named signatures.
    /// Every value must be a `ConcreteFunction`.
    /// </summary>
    public static SignatureMap create_signature_map(IDictionary<string, Trackable> signatures)
    {
        var signature_map = new SignatureMap();
        foreach (var (name, func) in signatures)
        {
            Debug.Assert(func is ConcreteFunction);
            // TODO: assert the `func.structured_outputs` and arg_keywords.
            signature_map._add_signature(name, (ConcreteFunction)func);
        }
        return signature_map;
    }

    /// <summary>
    /// Scans the children of the graph-view root for a function suitable for
    /// export; the default-save-signature child wins outright, otherwise a
    /// single valid candidate is returned, else null.
    /// </summary>
    public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view)
    {
        List<Trackable> candidates = new();
        foreach (var child_ref in graph_view.list_children(graph_view.Root))
        {
            if (child_ref.Refer is not (Function or ConcreteFunction))
            {
                continue;
            }
            if (child_ref.Name == DEFAULT_SIGNATURE_ATTR)
            {
                // The default save signature takes precedence over everything else.
                Debug.Assert(child_ref.Refer is ConcreteFunction);
                return (ConcreteFunction)child_ref.Refer;
            }
            var concrete = get_signature(child_ref.Refer);
            if (concrete is not null && valid_signature(concrete))
            {
                candidates.Add(concrete);
            }
        }
        if (candidates.Count == 1)
        {
            var signature = get_signature(candidates[0]);
            if (signature is not null && valid_signature(signature))
            {
                return signature;
            }
        }
        return null;
    }

    private static ConcreteFunction get_signature(Trackable function)
    {
        // TODO: implement it.
        return null;
    }

    private static bool valid_signature(ConcreteFunction concreate_function)
    {
        // TODO: implement it.
        return false;
    }
}
/// <summary>
/// A trackable container mapping signature names to functions, corresponding
/// to `_SignatureMap` in tensorflow python's signature_serialization.py.
/// </summary>
public class SignatureMap: Trackable
{
    // signature name -> Function / ConcreteFunction
    private Dictionary<string, Trackable> _signatures;

    public SignatureMap()
    {
        _signatures = new();
    }

    public void _add_signature(string name, ConcreteFunction concrete_function)
    {
        _signatures[name] = concrete_function;
    }

    public void _add_signature(string name, Function concrete_function)
    {
        _signatures[name] = concrete_function;
    }

    /// <summary>
    /// Exposes the stored signatures as trackable children, but only when
    /// saving a SavedModel (checkpoints do not serialize signatures).
    /// </summary>
    public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
    {
        if (save_type != SaveType.SAVEDMODEL)
        {
            return new Dictionary<string, Trackable>();
        }
        // BUGFIX: `TakeWhile` stops at the first non-function value and silently
        // drops every signature after it; `Where` filters the whole collection.
        return _signatures.Where(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value);
    }
}
@@ -0,0 +1,57 @@ | |||||
using System.IO; | |||||
using System.Security.Cryptography.X509Certificates; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow; | |||||
public static partial class SavedModelUtils
{
    /// <summary>
    /// Return variables sub-directory, or create one if it doesn't exist.
    /// </summary>
    /// <returns>The absolute/relative path of the variables directory.</returns>
    public static string get_or_create_variables_dir(string export_dir)
    {
        var dir = get_variables_dir(export_dir);
        // CreateDirectory is a no-op when the directory already exists.
        Directory.CreateDirectory(dir);
        return dir;
    }

    /// <summary>
    /// Return variables sub-directory in the SavedModel.
    /// </summary>
    /// <param name="export_dir">Root directory of the SavedModel.</param>
    public static string get_variables_dir(string export_dir)
        => Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY));

    /// <summary>
    /// Return the file prefix used for the variables checkpoint of the SavedModel.
    /// </summary>
    /// <param name="export_dir">Root directory of the SavedModel.</param>
    public static string get_variables_path(string export_dir)
        => Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME));

    /// <summary>
    /// Return assets sub-directory, or create one if it doesn't exist.
    /// </summary>
    /// <param name="export_dir">Root directory of the SavedModel.</param>
    public static string get_or_create_assets_dir(string export_dir)
    {
        var dir = get_assets_dir(export_dir);
        Directory.CreateDirectory(dir);
        return dir;
    }

    /// <summary>
    /// Return path to asset directory in the SavedModel.
    /// </summary>
    /// <param name="export_dir">Root directory of the SavedModel.</param>
    public static string get_assets_dir(string export_dir)
        => Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY));
}
@@ -16,12 +16,38 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Operations.Activation; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class saveable_object_util | |||||
/// <summary> | |||||
/// A SaveableObject that defines `Trackable` checkpointing steps. | |||||
/// </summary> | |||||
public class TrackableSaveable : MySaveableObject | |||||
{ | |||||
private string _prefix; | |||||
private IEnumerable<string> _local_names; | |||||
private Trackable _trackable; | |||||
private bool _call_with_mapped_captures; | |||||
// TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. | |||||
public TrackableSaveable(Trackable obj, IEnumerable<SaveSpec> specs, string name, IEnumerable<string> local_names, | |||||
string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) | |||||
{ | |||||
_prefix = prefix; | |||||
_trackable = obj; | |||||
_local_names = local_names; | |||||
_call_with_mapped_captures = call_with_mapped_captures; | |||||
} | |||||
// TODO: complete this class. | |||||
} | |||||
public static class saveable_object_util | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Returns the variables and names that will be used for a Saver. | /// Returns the variables and names that will be used for a Saver. | ||||
@@ -52,7 +78,7 @@ namespace Tensorflow | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Create `SaveableObject`s from an operation. | |||||
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | |||||
/// </summary> | /// </summary> | ||||
/// <param name="op"></param> | /// <param name="op"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
@@ -74,6 +100,73 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
        /// <summary>
        /// Create `SaveableObject`s for a trackable object.
        /// </summary>
        /// <param name="obj">A resource variable or any other `Trackable`.</param>
        /// <param name="name">Full checkpoint name of <paramref name="obj"/>.</param>
        /// <returns>Lazily produced saveables for <paramref name="obj"/> (recursing into its factories).</returns>
        public static IEnumerable<MySaveableObject> saveable_objects_for_op(Trackable obj, string name)
        {
            // The `obj` may be a `Variable` or any other `Trackable`.
            if (obj is BaseResourceVariable)
            {
                var variable = obj as BaseResourceVariable;
                if (variable.InGraphMode)
                {
                    // In graph mode, save the graph element rather than the handle.
                    yield return new ResourceVariableSaveable(variable.GraphElement, "", name);
                }
                else
                {
                    yield return new ResourceVariableSaveable(variable, "", name);
                }
            }
            else
            {
                foreach(var pair in saveable_objects_from_trackable(obj))
                {
                    var attr = pair.Key;
                    var factory = pair.Value;
                    string full_name;
                    if(attr == Trackable.Constants.VARIABLE_VALUE_KEY)
                    {
                        // The object itself is a variable value: keep the bare name.
                        full_name = name;
                    }
                    else
                    {
                        full_name = name + "_" + attr;
                    }
                    // NOTE(review): `full_name` is computed but never used below — the recursive
                    // calls pass the variable's own name instead. Presumably `full_name` was
                    // meant to be forwarded; confirm against the Python implementation.
                    if(factory.DataType == typeof(ResourceVariable))
                    {
                        var variable = factory.GetValueA();
                        foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
                        {
                            yield return op;
                        }
                    }
                    else
                    {
                        var variable = factory.GetValueB();
                        foreach (var op in saveable_objects_for_op(variable, variable.name))
                        {
                            yield return op;
                        }
                    }
                }
            }
        }
/// <summary> | |||||
/// Create `SaveableObject`s from an operation. | |||||
/// </summary> | |||||
/// <param name="op"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static IEnumerable<MySaveableObject> saveable_objects_for_op(MySaveableObject obj, string name) | |||||
{ | |||||
yield return obj; | |||||
} | |||||
public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) | public static Dictionary<string, Tensor> op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) | ||||
{ | { | ||||
op_list = op_list.OrderBy(x => x.Name).ToArray(); | op_list = op_list.OrderBy(x => x.Name).ToArray(); | ||||
@@ -121,5 +214,164 @@ namespace Tensorflow | |||||
return names_to_saveables; | return names_to_saveables; | ||||
} | } | ||||
        /// <summary>
        /// Returns the map of saveable factories for `obj`. If the object overrides
        /// `serialize_to_tensors`, its tensors are flattened into `SaveSpec`s and wrapped
        /// in a single `TrackableSaveable`; otherwise the object's own factories are used.
        /// </summary>
        /// <param name="obj">The object to gather saveables from.</param>
        public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
        {
            // skip the process of type `PythonState`
            if (trackable_has_serialize_to_tensor(obj))
            {
                var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
                // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
                var tensor_dict = obj.serialize_to_tensors();

                List<SaveSpec> specs = new();
                List<string> local_names = new();
                string prefix = SaveableCompat.get_saveable_name(obj) ?? "";
                foreach(var pair in tensor_dict)
                {
                    var tensor_name = pair.Key;
                    var maybe_tensor = pair.Value;
                    local_names.Add(tensor_name);
                    string spec_name = name + TrackableUtils.escape_local_name(tensor_name);

                    // Normalize to a slice-spec -> tensor map; a plain tensor is stored
                    // under the empty slice spec.
                    IDictionary<string, Tensor> internal_dict;
                    if(maybe_tensor.DataType == typeof(Tensor))
                    {
                        internal_dict= new Dictionary<string, Tensor>();
                        internal_dict[""] = maybe_tensor.GetValueA();
                    }
                    else
                    {
                        internal_dict = maybe_tensor.GetValueB();
                    }

                    foreach(var item in internal_dict)
                    {
                        specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
                    }
                }
                Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new();
                res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
                return res;
            }
            else
            {
                return obj.gather_saveables_for_checkpoint();
            }
        }
public static bool trackable_has_serialize_to_tensor(Trackable obj) | |||||
{ | |||||
return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); | |||||
} | |||||
internal static string convert_to_string(string x) | |||||
{ | |||||
return tf.compat.as_str(x); | |||||
} | |||||
        /// <summary>
        /// Converts a list of SaveableObjects to a tensor dictionary.
        /// </summary>
        /// <param name="saveables">Saveables whose specs are flattened into the result.</param>
        /// <returns>
        /// Map from spec name to either a plain tensor (empty slice spec) or a
        /// slice-spec -> tensor map (partitioned tensors).
        /// </returns>
        public static Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> saveable_object_to_tensor_dict(IList<MySaveableObject> saveables)
        {
            Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
            foreach (var saveable in saveables)
            {
                foreach (var spec in saveable.specs)
                {
                    // skip the check that if `spec` is callable.
                    var name = convert_to_string(spec.name);
                    var slice_spec = convert_to_string(spec.slice_spec);
                    if (!string.IsNullOrEmpty(slice_spec))
                    {
                        // Partitioned tensor: group all slices under the shared name.
                        tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor;
                    }
                    else
                    {
                        tensor_dict[name] = spec.tensor;
                    }
                }
            }
            return tensor_dict;
        }
        /// <summary>
        /// Generates `Trackable._restore_from_tensors` from SaveableObjects.
        /// </summary>
        /// <param name="saveables">Saveables whose `restore` methods will be invoked.</param>
        /// <returns>
        /// A function mapping restored tensors (keyed by local name; each value either a
        /// plain tensor or a slice-spec -> tensor map) to restore ops keyed by saveable name.
        /// </returns>
        public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables)
        {
            return (restored_tensors) =>
            {
                Dictionary<string, Operation> restored_ops = new();

                foreach(var saveable in saveables)
                {
                    List<Tensor> saveable_restored_tensors = new();
                    foreach(var spec in saveable.specs)
                    {
                        var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name));
                        var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec);

                        var maybe_tensor = restored_tensors[name];
                        IDictionary<string, Tensor> dict;

                        // Normalize to a slice-spec -> tensor map (plain tensors live
                        // under the empty slice spec), mirroring the save-side layout.
                        if(maybe_tensor.DataType == typeof(Tensor))
                        {
                            dict = new Dictionary<string, Tensor>();
                            dict[""] = maybe_tensor.GetValueA();
                        }
                        else
                        {
                            dict = maybe_tensor.GetValueB();
                        }
                        saveable_restored_tensors.Add(dict[slice_spec]);
                    }
                    restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null);
                }
                return restored_ops;
            };
        }
} | |||||
public class SaveableCompatibilityConverter: Trackable | |||||
{ | |||||
private object _obj; | |||||
private IList<MySaveableObject> _saveables; | |||||
public SaveableCompatibilityConverter(object obj, IList<MySaveableObject> saveables) | |||||
{ | |||||
_obj= obj; | |||||
_saveables= saveables; | |||||
} | |||||
public object Obj => _obj; | |||||
public IList<MySaveableObject> mySaveables=> _saveables; | |||||
public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||||
{ | |||||
return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | |||||
} | |||||
/// <summary> | |||||
/// Returns the restore ops defined in the Saveables. | |||||
/// </summary> | |||||
/// <param name="restored_tensors"></param> | |||||
/// <returns></returns> | |||||
public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||||
{ | |||||
List<string> expected_keys = new(); | |||||
foreach(var saveable in _saveables) | |||||
{ | |||||
expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); | |||||
} | |||||
if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) | |||||
{ | |||||
throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." + | |||||
$"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); | |||||
} | |||||
return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); | |||||
} | |||||
} | } | ||||
} | } |
@@ -14,13 +14,63 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.ModelSaving; | |||||
using Tensorflow.Training; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
{ | { | ||||
public abstract class Trackable | |||||
public abstract class Trackable: IWithTrackable | |||||
{ | { | ||||
        /// <summary>
        /// Corresponding to tensorflow/python/trackable/constants.py
        /// </summary>
        public static class Constants
        {
            // Checkpoint key under which the serialized object-graph proto is stored.
            public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH";
            // Attribute key for a variable's value tensor.
            public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE";
            // Attribute key for object configuration serialized as JSON.
            public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON";
        }
protected int _self_update_uid; | protected int _self_update_uid; | ||||
protected IDictionary<string, Trackable> _unconditional_dependency_names; | |||||
protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||||
protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||||
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||||
private bool _manual_tracking = true; | |||||
private static Trackable _none = new AutoTrackable(); | |||||
/// <summary> | |||||
/// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. | |||||
/// The `None` can be any object that inherits `Trackable`. | |||||
/// This Property is supposed to be used only internal. | |||||
/// </summary> | |||||
public static Trackable None | |||||
{ | |||||
get | |||||
{ | |||||
return _none; | |||||
} | |||||
} | |||||
public Trackable GetTrackable() | |||||
{ | |||||
return this; | |||||
} | |||||
public virtual string ObjectIdentifier | |||||
{ | |||||
get => "_generic_user_object"; | |||||
} | |||||
public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; } | |||||
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | |||||
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | |||||
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | |||||
/// <summary> | /// <summary> | ||||
/// Restore-on-create for a variable be saved with this `Checkpointable`. | /// Restore-on-create for a variable be saved with this `Checkpointable`. | ||||
@@ -47,9 +97,13 @@ namespace Tensorflow.Train | |||||
// assign again. It will add this variable to our dependencies, and if there | // assign again. It will add this variable to our dependencies, and if there | ||||
// is a non-trivial restoration queued, it will handle that. This also | // is a non-trivial restoration queued, it will handle that. This also | ||||
// handles slot variables. | // handles slot variables. | ||||
if (!args.Overwrite || new_variable is RefVariable) | |||||
return _track_checkpointable(new_variable, name: args.Name, | |||||
overwrite: args.Overwrite); | |||||
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) | |||||
{ | |||||
var temp = new_variable as Trackable; | |||||
var res = _track_trackable(temp, args.Name, args.Overwrite); | |||||
Debug.Assert(res is IVariableV1); | |||||
return res as IVariableV1; | |||||
} | |||||
else | else | ||||
return new_variable; | return new_variable; | ||||
} | } | ||||
@@ -73,10 +127,136 @@ namespace Tensorflow.Train | |||||
/// <summary> | /// <summary> | ||||
/// Initialize dependency management. | /// Initialize dependency management. | ||||
/// </summary> | /// </summary> | ||||
protected void _maybe_initialize_trackable() | |||||
public void _maybe_initialize_trackable() | |||||
{ | { | ||||
// _self_unconditional_checkpoint_dependencies = [] | |||||
if(_unconditional_checkpoint_dependencies is not null) | |||||
{ | |||||
return; | |||||
} | |||||
_self_update_uid = -1; | _self_update_uid = -1; | ||||
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||||
_unconditional_dependency_names = new Dictionary<string, Trackable>(); | |||||
} | |||||
        /// <summary>
        /// Returns this object's checkpoint dependencies as a local-name -> Trackable map.
        /// </summary>
        /// <param name="save_type">Whether children are gathered for a checkpoint or a SavedModel.</param>
        /// <param name="cache">Serialization cache shared across one save operation (unused here).</param>
        public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache)
        {
            _maybe_initialize_trackable();
            return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
        }
        /// <summary>
        /// Declares a checkpoint dependency on `trackable` under `name` and returns it.
        /// </summary>
        /// <param name="trackable">The dependency to track.</param>
        /// <param name="name">Local name of the dependency.</param>
        /// <param name="overwrite">Whether an existing dependency of the same name may be replaced.
        /// NOTE(review): currently unused — the Python implementation raises when a dependency
        /// exists and overwrite is false; confirm whether that check was omitted intentionally.</param>
        public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false)
        {
            _maybe_initialize_trackable();
            // Tracking can be disabled (e.g. during wrapping); just pass the value through.
            if (!_manual_tracking) return trackable;
            var new_reference = new TrackableReference(name, trackable);
            var current_object = _lookup_dependency(name);
            if(current_object is null)
            {
                // First time this name is seen: record it and flush any deferred restores.
                _unconditional_checkpoint_dependencies.Add(new_reference);
                _handle_deferred_dependencies(name, trackable);
            }
            // The name map is updated even when a dependency with this name already existed.
            _unconditional_dependency_names[name] = trackable;
            return trackable;
        }
        /// <summary>
        /// Pop and load any deferred checkpoint restores into `trackable`.
        /// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for
        /// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are
        /// considered fulfilled and so are deleted.
        /// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations.
        /// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves,
        /// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context
        /// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs).
        /// </summary>
        /// <param name="name">Local name the deferred restores were queued under.</param>
        /// <param name="trackable">The newly added dependency to restore into.</param>
        public virtual void _handle_deferred_dependencies(string name, Trackable trackable)
        {
            // Deferred restoration is not implemented yet; the method is intentionally a no-op.
            //_maybe_initialize_trackable();
            //trackable._maybe_initialize_trackable();

            // TODO: complete the implementation.
        }
public virtual Trackable? _lookup_dependency(string name) | |||||
{ | |||||
if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; | |||||
else return null; | |||||
} | |||||
public static Trackable convert_to_trackable(object obj, object? parent = null) | |||||
{ | |||||
if (obj is Trackable) | |||||
{ | |||||
return (Trackable)obj; | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
public virtual IDictionary<string, Trackable> deserialization_dependencies(IDictionary<string, Trackable> children) | |||||
{ | |||||
return new Dictionary<string, Trackable>(); | |||||
} | |||||
public virtual (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources( | |||||
SaveOptions? save_options) | |||||
{ | |||||
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); | |||||
} | |||||
public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable> object_map, | |||||
IDictionary<Tensor, Tensor> tensor_map, SaveOptions? options = null) | |||||
{ | |||||
var (self_object_map, self_tensor_map) = map_resources(options); | |||||
foreach (var pair in self_object_map) | |||||
{ | |||||
object_map.Add(pair); | |||||
} | |||||
foreach (var pair in self_tensor_map) | |||||
{ | |||||
tensor_map.Add(pair); | |||||
} | |||||
return self_tensor_map.Keys.ToList(); | |||||
} | |||||
        /// <summary>
        /// Returns a map from attribute name to saveable factories used to checkpoint this object.
        /// </summary>
        /// <exception cref="NotImplementedException">
        /// The object overrides `serialize_to_tensors`; that path is not implemented yet.
        /// </exception>
        public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
        {
            if (saveable_object_util.trackable_has_serialize_to_tensor(this))
            {
                // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
                throw new NotImplementedException();
            }
            else
            {
                return _self_saveable_object_factories;
            }
        }
        /// <summary>
        /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors`
        /// if you are defining a custom resource or variable with custom ops.
        /// Otherwise, please store the state of your trackable in `tf.Variable` objects
        /// and add them to Trackable object hierarchy using `setattr` (for subclasses
        /// of `AutoTrackable`) or overriding the `_trackable_children` method.
        /// </summary>
        /// <returns>Map from local tensor name to the tensor(s) to save.</returns>
        /// <exception cref="NotImplementedException">Always, unless a subclass overrides this method.</exception>
        public virtual IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
        {
            throw new NotImplementedException();
        }
        /// <summary>
        /// Restores state from the tensors produced by <see cref="serialize_to_tensors"/>.
        /// Override together with `serialize_to_tensors` for custom resources.
        /// </summary>
        /// <param name="restored_tensors">Tensors read from the checkpoint, keyed by local name.</param>
        /// <returns>Restore ops keyed by name.</returns>
        /// <exception cref="NotImplementedException">Always, unless a subclass overrides this method.</exception>
        public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
        {
            throw new NotImplementedException();
        }
} | } | ||||
    /// <summary>A named reference from a parent trackable to one of its dependencies.</summary>
    public record class TrackableReference(string Name, Trackable Refer);
} | } |
@@ -0,0 +1,172 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Exceptions; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Training; | |||||
/// <summary>
/// Utilities for naming and ordering trackable objects in checkpoints
/// (port of tensorflow/python/checkpoint/trackable_view & util helpers).
/// </summary>
public static class TrackableUtils
{
    /// <summary>
    /// Thrown by <see cref="order_by_dependency"/> when the dependency map contains a cycle.
    /// </summary>
    public class CyclicDependencyError : System.Exception
    {
        /// <summary>The part of the (reversed) dependency map that could not be resolved.</summary>
        public IDictionary<int, IEnumerable<int>> LeftOverDependencyMap { get; }

        public CyclicDependencyError(IDictionary<int, IEnumerable<int>> leftover_dependency_map) : base()
        {
            LeftOverDependencyMap = leftover_dependency_map;
        }

        public CyclicDependencyError(IDictionary<int, List<int>> leftover_dependency_map) : base()
        {
            LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable());
        }
    }

    // Character used to escape reserved characters inside local names.
    private static string _ESCAPE_CHAR = ".";
    private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
    // Infix marking attribute entries in checkpoint keys ("/.ATTRIBUTES/").
    private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
    // Local name used for objects saved through `serialize_to_tensors`.
    internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";

    /// <summary>
    /// Joins an object path into a single "/"-separated string of escaped local names.
    /// </summary>
    public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
    {
        return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name)));
    }

    /// <summary>
    /// Escapes a local name for use in a checkpoint key: the escape character is
    /// doubled and "/" becomes the escape character followed by "S".
    /// </summary>
    public static string escape_local_name(string name)
    {
        return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S");
    }

    /// <summary>
    /// Builds the full checkpoint key for an attribute of an object.
    /// </summary>
    /// <param name="object_path">Path of the object in the object graph.</param>
    /// <param name="local_name">Local attribute name (escaped before use).</param>
    public static string checkpoint_key(string object_path, string local_name)
    {
        var key_suffix = escape_local_name(local_name);
        if (local_name == SERIALIZE_TO_TENSORS_NAME)
        {
            // TENSORS-serialized objects use an empty suffix after the marker.
            key_suffix = "";
        }
        return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}";
    }

    /// <summary>
    /// Topologically sorts the keys of a map so that dependencies appear first.
    /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
    /// </summary>
    /// <param name="dependency_map">Map from node to the nodes it depends on; every value must also appear as a key.</param>
    /// <exception cref="ValueError">A value in the map is not one of its keys.</exception>
    /// <exception cref="CyclicDependencyError">The dependency graph contains a cycle.</exception>
    public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map)
    {
        // reverse_dependency_map[dep] = set of nodes that depend on `dep`.
        Dictionary<int, HashSet<int>> reverse_dependency_map = new();
        foreach (var pair in dependency_map)
        {
            foreach (var dep in pair.Value)
            {
                // Single lookup instead of ContainsKey + indexer.
                if (reverse_dependency_map.TryGetValue(dep, out var dependents))
                {
                    dependents.Add(pair.Key);
                }
                else
                {
                    reverse_dependency_map[dep] = new HashSet<int> { pair.Key };
                }
            }
        }

        // Validate that all values in the dependency map are also keys.
        var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys);
        if (unknown_keys.Any())
        {
            throw new ValueError(
                $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}");
        }

        // Generate the list sorted by objects without dependencies -> dependencies.
        // The returned list will reverse this.
        List<int> reversed_dependency_arr = new();
        Queue<int> to_visit = new();
        foreach (var x in dependency_map.Keys)
        {
            // Nodes nothing depends on are the starting frontier.
            if (!reverse_dependency_map.ContainsKey(x))
            {
                to_visit.Enqueue(x);
            }
        }

        while (to_visit.Count > 0)
        {
            var x = to_visit.Dequeue();
            reversed_dependency_arr.Add(x);
            foreach (var dep in dependency_map[x].Distinct())
            {
                var edges = reverse_dependency_map[dep];
                edges.Remove(x);
                if (edges.Count == 0)
                {
                    // All dependents of `dep` are processed; `dep` can be emitted next.
                    to_visit.Enqueue(dep);
                    if (!reverse_dependency_map.Remove(dep))
                    {
                        throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map");
                    }
                }
            }
        }

        if (reverse_dependency_map.Count > 0)
        {
            // Anything left has an unsatisfiable dependency, i.e. the graph is cyclic.
            Dictionary<int, List<int>> leftover_dependency_map = new();
            foreach (var pair in reverse_dependency_map)
            {
                foreach (var x in pair.Value)
                {
                    if (leftover_dependency_map.TryGetValue(x, out var deps))
                    {
                        deps.Add(pair.Key);
                    }
                    else
                    {
                        leftover_dependency_map[x] = new List<int>() { pair.Key };
                    }
                }
            }
            throw new CyclicDependencyError(leftover_dependency_map);
        }

        reversed_dependency_arr.Reverse();
        return reversed_dependency_arr;
    }

    /// <summary>
    /// Renders an object path as a human-readable string rooted at "root".
    /// </summary>
    public static string pretty_print_node_path(IEnumerable<TrackableReference> paths)
    {
        return paths.Any()
            ? $"root.{string.Join(".", paths.Select(x => x.Name))}"
            : "root object";
    }

    /// <summary>
    /// Returns the substring after the "/.ATTRIBUTES/" (plus optional prefix) in the
    /// checkpoint key, or the key itself when the marker is absent.
    /// </summary>
    /// <param name="key">A checkpoint key.</param>
    /// <param name="prefix">Optional prefix expected right after the marker.</param>
    public static string extract_local_name(string key, string? prefix = null)
    {
        prefix ??= "";
        var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix;
        // Fix: the previous implementation passed IndexOf's -1 ("not found") straight
        // into Substring, returning a bogus suffix whenever the key was long enough
        // not to throw; only out-of-range starts fell back to the full key.
        var index = key.IndexOf(search_key, StringComparison.Ordinal);
        if (index < 0)
        {
            return key;
        }
        return key.Substring(index + search_key.Length);
    }
}
@@ -0,0 +1,370 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections; | |||||
using System.Collections.Generic; | |||||
using System.IO.Compression; | |||||
using System.Linq; | |||||
using System.Linq.Expressions; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
using Tensorflow.Functions; | |||||
using Tensorflow.Keras; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Operations.Activation; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.ApiDef.Types; | |||||
namespace Tensorflow.Training | |||||
{ | |||||
public class NoDependency | |||||
{ | |||||
public Trackable Value { get; set; } | |||||
public NoDependency(Trackable value) | |||||
{ | |||||
Value = value; | |||||
} | |||||
} | |||||
public abstract class TrackableDataStructure : Trackable | |||||
{ | |||||
private bool _self_trainable; | |||||
private List<IVariableV1> _self_extra_variables; | |||||
public TrackableDataStructure() | |||||
{ | |||||
_self_trainable = true; | |||||
_self_extra_variables = new List<IVariableV1>(); | |||||
} | |||||
public abstract IEnumerable<Trackable> Values { get; } | |||||
public bool Trainable { get => _self_trainable; set => _self_trainable = value; } | |||||
public IEnumerable<ILayer> Layers | |||||
{ | |||||
get | |||||
{ | |||||
List<ILayer> collected = new(); | |||||
foreach(var obj in Values) | |||||
{ | |||||
if(obj is ILayer) | |||||
{ | |||||
collected.Add((ILayer)obj); | |||||
} | |||||
else if(obj is TrackableDataStructure) | |||||
{ | |||||
collected.AddRange((obj as TrackableDataStructure).Layers); | |||||
} | |||||
} | |||||
return collected; | |||||
} | |||||
} | |||||
public IEnumerable<IVariableV1> TrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
if (!_self_trainable) | |||||
{ | |||||
return new List<IVariableV1>(); | |||||
} | |||||
List<IVariableV1> trainable_variables = new(); | |||||
foreach (var obj in Values) | |||||
{ | |||||
// skip the process of `module.Module`. | |||||
if (obj is TrackableDataStructure) | |||||
{ | |||||
trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); | |||||
} | |||||
} | |||||
foreach(var v in _self_extra_variables) | |||||
{ | |||||
if (v.Trainable) | |||||
{ | |||||
trainable_variables.Add(v); | |||||
} | |||||
} | |||||
return trainable_variables; | |||||
} | |||||
} | |||||
public IEnumerable<IVariableV1> NonTrainableWeights | |||||
{ | |||||
get | |||||
{ | |||||
var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList(); | |||||
var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList(); | |||||
List<IVariableV1> non_trainable_variables = new(); | |||||
foreach(var obj in Values) | |||||
{ | |||||
// skip the process of `module.Module`. | |||||
if (obj is TrackableDataStructure) | |||||
{ | |||||
non_trainable_variables.AddRange((obj as TrackableDataStructure).NonTrainableVariables); | |||||
} | |||||
} | |||||
if (!_self_trainable) | |||||
{ | |||||
// Return order is all trainable vars, then all non-trainable vars. | |||||
List<IVariableV1> trainable_variables = new(); | |||||
foreach(var obj in Values) | |||||
{ | |||||
// skip the process of `module.Module`. | |||||
if (obj is TrackableDataStructure) | |||||
{ | |||||
trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); | |||||
} | |||||
} | |||||
return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables); | |||||
} | |||||
else | |||||
{ | |||||
return non_trainable_variables.concat(non_trainable_extra_variables); | |||||
} | |||||
} | |||||
} | |||||
public IEnumerable<IVariableV1> Weights => TrainableWeights.Concat(NonTrainableWeights); | |||||
public IEnumerable<IVariableV1> TrainableVariables => TrainableWeights; | |||||
public IEnumerable<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||||
public IEnumerable<IVariableV1> Variables => Weights; | |||||
// TODO: `losses` property. | |||||
/// <summary> | |||||
/// Add a dependency on `value`. | |||||
/// </summary> | |||||
/// <param name="value"></param> | |||||
/// <param name="name"></param> | |||||
protected virtual Trackable _track_value(Trackable value, string name) | |||||
{ | |||||
value = sticky_attribute_assignment(this, name, value); | |||||
if(value is IVariableV1) | |||||
{ | |||||
_self_extra_variables.Add(value as IVariableV1); | |||||
} | |||||
// skip the left process (need to be done in the future). | |||||
return value; | |||||
} | |||||
public static Trackable wrap_or_unwrap(NoDependency value) | |||||
{ | |||||
return value.Value; | |||||
} | |||||
public static Trackable wrap_or_unwrap(Trackable value) | |||||
{ | |||||
return value; | |||||
} | |||||
public static Trackable wrap_or_unwrap(IList<Trackable> value) | |||||
{ | |||||
return new ListWrapper(value); | |||||
} | |||||
public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value) | |||||
{ | |||||
return new ListWrapper(value.ToList()); | |||||
} | |||||
protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | |||||
{ | |||||
value = wrap_or_unwrap(value); | |||||
trackable._track_trackable(value, name, true); | |||||
return value; | |||||
} | |||||
protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) | |||||
{ | |||||
var wrapped_value = wrap_or_unwrap(value); | |||||
trackable._track_trackable(wrapped_value, name, true); | |||||
return wrapped_value; | |||||
} | |||||
protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value) | |||||
{ | |||||
var wrapped_value = wrap_or_unwrap(value); | |||||
trackable._track_trackable(wrapped_value, name, true); | |||||
return wrapped_value; | |||||
} | |||||
} | |||||
public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable | |||||
{ | |||||
private IList<Trackable> _storage; | |||||
private bool _non_append_mutation_value; | |||||
private bool _external_modification_value; | |||||
private IList<Trackable> _last_wrapped_list_snapshot; | |||||
        /// <summary>
        /// Wraps a list for dependency tracking.
        /// </summary>
        /// <param name="wrapped_list">The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be
        /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param>
        public ListWrapper(IList<Trackable> wrapped_list)
        {
            // The wrapped list is aliased, not copied; only the snapshot below is a copy.
            _storage = wrapped_list;
            _non_append_mutation_value = _external_modification_value = false;
            // Shallow snapshot used to detect out-of-band modifications at save time.
            _last_wrapped_list_snapshot = new List<Trackable>(_storage);
        }
        // True once an element has been replaced, deleted or reordered (anything other
        // than an append); such lists can no longer be checkpointed.
        protected bool NonAppendMuation {
            get => _non_append_mutation_value;
            set
            {
                // TODO: deal with `attribute_sentinel`.
                _non_append_mutation_value = value;
            }
        }

        // True once the wrapped list was modified directly, bypassing this wrapper.
        protected bool ExternalModification
        {
            get => _external_modification_value;
            set
            {
                // TODO: deal with `attribute_sentinel`.
                _external_modification_value = value;
            }
        }

        // The wrapper enumerates its own elements.
        public override IEnumerable<Trackable> Values => this;
        public bool IsReadOnly { get => _storage.IsReadOnly; }
        /// <summary>
        /// Checks for any changes to the wrapped list not through the wrapper.
        /// </summary>
        private void check_external_modification()
        {
            // Once a mutation has already been flagged there is nothing more to detect.
            if (_external_modification_value || _non_append_mutation_value) return;
            // Compare against the last snapshot taken by `update_snapshot`.
            if (!_storage.SequenceEqual(_last_wrapped_list_snapshot))
            {
                _external_modification_value = true;
            }
        }
        // Refreshes the shallow snapshot used by `check_external_modification`.
        // Skipped once a mutation has been flagged, so the offending old snapshot
        // is preserved for the error message.
        private void update_snapshot()
        {
            // TODO: deal with `attribute_sentinel`.
            if (_external_modification_value || _non_append_mutation_value) return;
            _last_wrapped_list_snapshot = new List<Trackable>(_storage);
        }
/// <summary>
/// Returns the trackable children of this wrapper. Throws if the wrapped list was
/// mutated in a way that cannot be restored (non-append mutation, or modification
/// made outside the wrapper).
/// </summary>
/// <param name="save_type">The kind of save being performed.</param>
/// <param name="cache">Optional serialization cache shared across trackables.</param>
/// <returns>Child trackables keyed by name (and by list index for functions in SavedModel mode).</returns>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
    check_external_modification();
    if (_non_append_mutation_value)
    {
        throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" +
            $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." +
            $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.");
    }
    if (_external_modification_value)
    {
        throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified " +
            $"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks " +
            $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored.");
    }
    var children = base._trackable_children(save_type, cache);
    if (save_type == SaveType.SAVEDMODEL)
    {
        // Key each function-valued element by its position in the list. `Where` is used
        // instead of `TakeWhile` so that functions appearing after a non-function element
        // are still included, and `idx` is the element's index in the full list rather
        // than in the filtered sequence (matching the keys used on restore).
        var function_children = this
            .Select((value, idx) => (value, idx))
            .Where(p => p.value is Function or ConcreteFunction)
            .Select(p => new KeyValuePair<string, Trackable>(p.idx.ToString(), p.value));
        children = children.Concat(function_children).ToDictionary(x => x.Key, x => x.Value);
    }
    return children;
}
/// <summary>
/// Short-circuit check consulted by the removal operations.
/// NOTE(review): this returns only the existing mutation flag, so callers of the form
/// `if (has_mutation_or_trackable()) _non_append_mutation_value = true;` can never newly
/// set the flag — presumably the check should also consider whether any stored element
/// is trackable; confirm intended behavior.
/// </summary>
private bool has_mutation_or_trackable()
{
return _non_append_mutation_value;
}
/// <summary>
/// Allows storage of non-trackable objects: if normal tracking rejects the value,
/// fall back to sticky attribute assignment instead of failing.
/// </summary>
/// <param name="value">The value to track.</param>
/// <param name="name">The name under which the value is tracked.</param>
/// <returns>The (possibly re-wrapped) value.</returns>
protected override Trackable _track_value(Trackable value, string name)
{
    try
    {
        base._track_value(value, name);
    }
    catch (ValueError)
    {
        // The exception itself is not inspected (previously bound to an unused variable);
        // the fallback only needs to know that tracking was rejected.
        value = sticky_attribute_assignment(this, name, value);
    }
    return value;
}
// Produces a shallow copy that carries over the error-state flags.
public object Clone()
{
    // Copy the storage so the clone maintains its own snapshot and flags.
    var copied_storage = _storage.ToList();
    var clone = new ListWrapper(copied_storage);
    clone.NonAppendMuation = _non_append_mutation_value;
    clone.ExternalModification = _external_modification_value;
    return clone;
}
// Indexed access; replacing an element is a non-append mutation.
public Trackable this[int index]
{
    get => _storage[index];
    set
    {
        // Skip the process of `Slice`, maybe support it in the future.
        _non_append_mutation_value = true;
        _storage[index] = _track_value(value, _name_element(index));
        update_snapshot();
    }
}
// Position lookup is delegated straight to the wrapped list.
public int IndexOf(Trackable item) => _storage.IndexOf(item);
/// <summary>
/// Inserts `item` at `index`. Insertion is a non-append mutation, so the wrapper
/// becomes unsaveable afterwards (see `_trackable_children`).
/// </summary>
public void Insert(int index, Trackable item)
{
// Detect out-of-band edits before mutating, so they are not masked by the snapshot update.
check_external_modification();
_non_append_mutation_value = true;
_storage.Insert(index, item);
update_snapshot();
}
/// <summary>
/// Removes the element at `index` from the wrapped list.
/// </summary>
public void RemoveAt(int index)
{
check_external_modification();
// NOTE(review): `has_mutation_or_trackable()` only returns the current flag, so this
// conditional can never newly mark the list as mutated — removing an element therefore
// does not flag a non-append mutation here; confirm intended.
if (has_mutation_or_trackable())
{
_non_append_mutation_value = true;
}
_storage.RemoveAt(index);
update_snapshot();
}
// Number of elements currently held by the wrapped list.
public int Count { get => _storage.Count; }
/// <summary>
/// Appends `item`; appends are the only mutations that remain checkpoint-restorable.
/// NOTE(review): unlike the indexer setter, `Add` does not route the item through
/// `_track_value` — confirm intended.
/// </summary>
public void Add(Trackable item)
{
// Detect out-of-band edits before mutating, so they are not masked by the snapshot update.
check_external_modification();
_storage.Add(item);
update_snapshot();
}
/// <summary>
/// Removes all elements. Clearing deletes elements, i.e. a non-append mutation, so the
/// flag and snapshot are updated consistently with `Insert`/`RemoveAt`. The previous
/// implementation delegated directly to `_storage.Clear()`, which skipped all
/// bookkeeping and made a wrapper-initiated clear look like an external modification
/// at save time.
/// </summary>
public void Clear()
{
    check_external_modification();
    if (_storage.Count > 0)
    {
        _non_append_mutation_value = true;
    }
    _storage.Clear();
    update_snapshot();
}
// Membership test is delegated to the wrapped list.
public bool Contains(Trackable item) => _storage.Contains(item);
// Copies the wrapped list's elements into `array` starting at `arrayIndex`.
public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex);
/// <summary>
/// Removes the first occurrence of `item`; returns whether an element was removed.
/// </summary>
public bool Remove(Trackable item)
{
check_external_modification();
// NOTE(review): `has_mutation_or_trackable()` only returns the current flag, so this
// conditional can never newly mark the list as mutated — removing an element therefore
// does not flag a non-append mutation here; confirm intended.
if (has_mutation_or_trackable())
{
_non_append_mutation_value = true;
}
var res = _storage.Remove(item);
update_snapshot();
return res;
}
// Enumeration is delegated to the wrapped list.
public IEnumerator<Trackable> GetEnumerator() => _storage.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator();
// Name under which the element stored at `index` is tracked.
protected string _name_element(int index) => $"{index}";
} | |||||
} |
@@ -2,14 +2,20 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Collections.Generic; | |||||
using Tensorflow.ModelSaving; | |||||
using System.Diagnostics; | |||||
using Tensorflow.Checkpoint; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class BaseResourceVariable : DisposableObject | |||||
public class BaseResourceVariable : DisposableTrackableObject | |||||
{ | { | ||||
protected string _name; | protected string _name; | ||||
public virtual string Name => _handle_name; | public virtual string Name => _handle_name; | ||||
public virtual string SharedName => _name; | |||||
protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
protected string _handle_name; | protected string _handle_name; | ||||
@@ -19,9 +25,10 @@ namespace Tensorflow | |||||
public string UniqueId => _unique_id; | public string UniqueId => _unique_id; | ||||
protected bool _in_graph_mode; | protected bool _in_graph_mode; | ||||
internal bool InGraphMode => _in_graph_mode; | |||||
protected bool _trainable; | protected bool _trainable; | ||||
public bool trainable => _trainable; | |||||
public bool Trainable => _trainable; | |||||
protected Tensor _initial_value; | protected Tensor _initial_value; | ||||
@@ -46,6 +53,7 @@ namespace Tensorflow | |||||
public Graph Graph => handle.graph; | public Graph Graph => handle.graph; | ||||
public string Device => handle.Device; | public string Device => handle.Device; | ||||
EagerResourceDeleter eager_resource_deleter; | EagerResourceDeleter eager_resource_deleter; | ||||
public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; | |||||
public BaseResourceVariable() | public BaseResourceVariable() | ||||
{ | { | ||||
@@ -73,6 +81,11 @@ namespace Tensorflow | |||||
_handle = handle.EagerTensorHandle.DangerousGetHandle(); | _handle = handle.EagerTensorHandle.DangerousGetHandle(); | ||||
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | ||||
} | } | ||||
else if(handle is null) | |||||
{ | |||||
// TODO: fix this dangerous change. | |||||
_handle = IntPtr.Zero; | |||||
} | |||||
else | else | ||||
{ | { | ||||
_handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle(); | _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle(); | ||||
@@ -165,7 +178,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
void variable_accessed(BaseResourceVariable variable) | void variable_accessed(BaseResourceVariable variable) | ||||
{ | { | ||||
if (variable.trainable) | |||||
if (variable.Trainable) | |||||
{ | { | ||||
foreach (var tape in tf.GetTapeSet()) | foreach (var tape in tf.GetTapeSet()) | ||||
tape.VariableAccessed(variable as ResourceVariable); | tape.VariableAccessed(variable as ResourceVariable); | ||||
@@ -243,5 +256,60 @@ namespace Tensorflow | |||||
else | else | ||||
return value(); | return value(); | ||||
} | } | ||||
public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options) | |||||
{ | |||||
BaseResourceVariable new_variable; | |||||
if (save_options.experimental_variable_policy.save_variable_devices()) | |||||
{ | |||||
tf.device(this.Device); | |||||
Debug.Assert(this is ResourceVariable); | |||||
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||||
} | |||||
else | |||||
{ | |||||
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||||
} | |||||
Dictionary<Trackable, Trackable> obj_map = new(); | |||||
Dictionary<Tensor, Tensor> resource_map = new(); | |||||
obj_map[this] = new_variable; | |||||
resource_map[this.handle] = new_variable.handle; | |||||
return (obj_map, resource_map); | |||||
} | |||||
/// <summary> | |||||
/// Writes additional information of the variable into the SavedObject proto. | |||||
/// Subclasses of ResourceVariables could choose to override this method to | |||||
/// customize extra information to provide when saving a SavedModel. | |||||
/// </summary> | |||||
/// <param name="proto"></param> | |||||
/// <param name="options"></param> | |||||
public virtual void write_object_proto(SavedObject proto, SaveOptions options) | |||||
{ | |||||
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | |||||
} | |||||
public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||||
{ | |||||
var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||||
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||||
return res; | |||||
} | |||||
public Tensor is_initialized(string name = null) | |||||
{ | |||||
return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); | |||||
} | |||||
public Tensor read_value_no_copy() | |||||
{ | |||||
Tensor value = null; | |||||
tf_with(ops.name_scope("Read"), _ => | |||||
{ | |||||
// TODO: `no_copy = true`. | |||||
value = _read_variable_op(); | |||||
}); | |||||
return array_ops.identity(value); | |||||
} | |||||
} | } | ||||
} | } |
@@ -46,6 +46,7 @@ namespace Tensorflow | |||||
Graph Graph { get; } | Graph Graph { get; } | ||||
TF_DataType dtype { get; } | TF_DataType dtype { get; } | ||||
Shape shape { get; } | Shape shape { get; } | ||||
bool Trainable { get; } | |||||
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); | IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); | ||||
@@ -20,11 +20,12 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.Train; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
[Obsolete] | [Obsolete] | ||||
public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
public partial class RefVariable: Trackable, IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
{ | { | ||||
protected string _name; | protected string _name; | ||||
public string UniqueId => _name; | public string UniqueId => _name; | ||||
@@ -56,6 +57,7 @@ namespace Tensorflow | |||||
public string Name => _variable.name; | public string Name => _variable.name; | ||||
public Tensor eval() => _variable; | public Tensor eval() => _variable; | ||||
public bool Trainable => _trainable; | |||||
public RefVariable(object initial_value = null, | public RefVariable(object initial_value = null, | ||||
bool trainable = true, | bool trainable = true, | ||||
@@ -17,7 +17,9 @@ | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Checkpoint; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -39,6 +41,7 @@ namespace Tensorflow | |||||
VariableAggregation aggregation = VariableAggregation.None, | VariableAggregation aggregation = VariableAggregation.None, | ||||
Shape shape = null) | Shape shape = null) | ||||
{ | { | ||||
Aggregation = aggregation; | |||||
if (variable_def != null) | if (variable_def != null) | ||||
{ | { | ||||
if (initial_value != null) | if (initial_value != null) | ||||
@@ -0,0 +1,70 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Gradients; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Variables | |||||
{ | |||||
/// <summary> | |||||
/// A variable with no initializer. | |||||
/// </summary> | |||||
public sealed class UninitializedVariable: BaseResourceVariable | |||||
{ | |||||
// TODO: complete the arg list. | |||||
public UninitializedVariable( | |||||
bool trainable = true, | |||||
string caching_device = "", | |||||
string name = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
VariableAggregation aggregation = VariableAggregation.None, | |||||
Shape shape = null, | |||||
Tensor extra_handle_data = null) | |||||
{ | |||||
string unique_id = ""; | |||||
string handle_name = ""; | |||||
tf_with(ops.init_scope(), (x) => | |||||
{ | |||||
_in_graph_mode = !tf.Context.executing_eagerly(); | |||||
tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name => | |||||
{ | |||||
handle_name = ops.name_from_scope_name(name); | |||||
string? shared_name; | |||||
if (_in_graph_mode) | |||||
{ | |||||
shared_name = handle_name; | |||||
unique_id = shared_name; | |||||
} | |||||
else | |||||
{ | |||||
unique_id = $"{handle_name}-{ops.uid()}"; | |||||
shared_name = null; | |||||
} | |||||
var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( | |||||
shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); | |||||
// skip the assignment of `handle._parent_trackable` because of lack of API. | |||||
// skip the assignment of `handle._name` and `handle._unique_id` because of accessibility. | |||||
if (_in_graph_mode) | |||||
{ | |||||
tf_with(ops.name_scope("Read"), _ => | |||||
{ | |||||
tf.device(handle.Device); | |||||
var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
// _maybe_set_handle_data(dtype, handle, value) | |||||
_graph_element = value; | |||||
}); | |||||
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | |||||
} | |||||
else | |||||
{ | |||||
_graph_element = null; | |||||
} | |||||
}); | |||||
}); | |||||
_shape = shape; | |||||
_dtype = dtype; | |||||
base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); | |||||
} | |||||
} | |||||
} |
@@ -566,5 +566,23 @@ namespace Tensorflow | |||||
else | else | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
public static bool inside_function() | |||||
{ | |||||
return get_default_graph().building_function; | |||||
} | |||||
public static void dismantle_graph(Graph graph) | |||||
{ | |||||
} | |||||
public class NullContextManager: IDisposable | |||||
{ | |||||
public void Dispose() | |||||
{ | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
public partial class Functional | public partial class Functional | ||||
{ | { | ||||
public ModelConfig get_config() | |||||
public override IKerasConfig get_config() | |||||
{ | { | ||||
return get_network_config(); | return get_network_config(); | ||||
} | } | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
Name = name | Name = name | ||||
}; | }; | ||||
var node_conversion_map = new Dictionary<string, int>(); | var node_conversion_map = new Dictionary<string, int>(); | ||||
foreach (var layer in _self_tracked_trackables) | foreach (var layer in _self_tracked_trackables) | ||||
{ | { | ||||
@@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
var layer_configs = new List<LayerConfig>(); | var layer_configs = new List<LayerConfig>(); | ||||
foreach (var layer in _self_tracked_trackables) | |||||
using (SharedObjectSavingScope.Enter()) | |||||
{ | { | ||||
var filtered_inbound_nodes = new List<NodeConfig>(); | |||||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||||
foreach (var layer in _self_tracked_trackables) | |||||
{ | { | ||||
var node_key = _make_node_key(layer.Name, original_node_index); | |||||
if (NetworkNodes.Contains(node_key) && !node.is_input) | |||||
var filtered_inbound_nodes = new List<NodeConfig>(); | |||||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||||
{ | { | ||||
var node_data = node.serialize(_make_node_key, node_conversion_map); | |||||
filtered_inbound_nodes.append(node_data); | |||||
var node_key = _make_node_key(layer.Name, original_node_index); | |||||
if (NetworkNodes.Contains(node_key) && !node.is_input) | |||||
{ | |||||
var node_data = node.serialize(_make_node_key, node_conversion_map); | |||||
filtered_inbound_nodes.append(node_data); | |||||
} | |||||
} | } | ||||
} | |||||
var layer_config = generic_utils.serialize_keras_object(layer); | |||||
layer_config.Name = layer.Name; | |||||
layer_config.InboundNodes = filtered_inbound_nodes; | |||||
layer_configs.Add(layer_config); | |||||
var layer_config = generic_utils.serialize_layer_to_config(layer); | |||||
layer_config.Name = layer.Name; | |||||
layer_config.InboundNodes = filtered_inbound_nodes; | |||||
layer_configs.Add(layer_config); | |||||
} | |||||
} | } | ||||
config.Layers = layer_configs; | config.Layers = layer_configs; | ||||
@@ -2,7 +2,9 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
@@ -20,6 +22,30 @@ namespace Tensorflow.Keras.Engine | |||||
Dictionary<long, int> tensor_usage_count; | Dictionary<long, int> tensor_usage_count; | ||||
/// <summary> | |||||
/// Dictionary of layer dependencies to be included in the checkpoint. | |||||
/// </summary> | |||||
public IDictionary<string, ILayer> LayerCheckpointDependencies | |||||
{ | |||||
get | |||||
{ | |||||
int weight_layer_index = 0; | |||||
Dictionary<string, ILayer> dependencies = new(); | |||||
for(int i = 0; i < Layers.Count; i++) | |||||
{ | |||||
var layer = Layers[i]; | |||||
var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList(); | |||||
if(weights.Count > 0) | |||||
{ | |||||
dependencies[$"layer_with_weights-{weight_layer_index}"] = layer; | |||||
weight_layer_index++; | |||||
} | |||||
dependencies[$"layer-{i}"] = layer; | |||||
} | |||||
return dependencies; | |||||
} | |||||
} | |||||
public Functional(Tensors inputs, Tensors outputs, string name = null) | public Functional(Tensors inputs, Tensors outputs, string name = null) | ||||
: base(new ModelArgs | : base(new ModelArgs | ||||
{ | { | ||||
@@ -44,6 +70,7 @@ namespace Tensorflow.Keras.Engine | |||||
this.inputs = inputs; | this.inputs = inputs; | ||||
this.outputs = outputs; | this.outputs = outputs; | ||||
built = true; | built = true; | ||||
_buildInputShape = inputs.shape; | |||||
if (outputs.Any(x => x.KerasHistory == null)) | if (outputs.Any(x => x.KerasHistory == null)) | ||||
base_layer_utils.create_keras_history(outputs); | base_layer_utils.create_keras_history(outputs); | ||||
@@ -325,5 +352,28 @@ namespace Tensorflow.Keras.Engine | |||||
return output_tensors; | return output_tensors; | ||||
} | } | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | |||||
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | |||||
.ToDictionary(x => x.Key, x => x.Value); | |||||
} | |||||
protected override void _init_set_name(string name, bool zero_based = true) | |||||
{ | |||||
if (string.IsNullOrEmpty(name)) | |||||
{ | |||||
string class_name = GetType().Name; | |||||
if (this.GetType() == typeof(Functional)) | |||||
{ | |||||
class_name = "Model"; | |||||
} | |||||
this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); | |||||
} | |||||
else | |||||
{ | |||||
this.name = name; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,32 @@ | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Keras.Engine; | |||||
public abstract partial class Layer | |||||
{ | |||||
public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | |||||
public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||||
public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | |||||
IDictionary<string, Trackable> children; | |||||
if (save_type == SaveType.SAVEDMODEL) | |||||
{ | |||||
Debug.Assert(cache is not null); | |||||
children = TrackableSavedModelSaver.trackable_children(cache); | |||||
} | |||||
else | |||||
{ | |||||
children = new Dictionary<string, Trackable>(); | |||||
} | |||||
return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value); | |||||
} | |||||
} |
@@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine | |||||
public bool Built => built; | public bool Built => built; | ||||
public bool Trainable => args.Trainable; | public bool Trainable => args.Trainable; | ||||
public TF_DataType DType => args.DType; | public TF_DataType DType => args.DType; | ||||
public bool AutoCast => args.Autocast; | |||||
public IRegularizer ActivityRegularizer => args.ActivityRegularizer; | |||||
/// <summary> | /// <summary> | ||||
/// A stateful layer is a layer whose updates are run during inference too, | /// A stateful layer is a layer whose updates are run during inference too, | ||||
@@ -59,6 +61,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// Provides information about which inputs are compatible with the layer. | /// Provides information about which inputs are compatible with the layer. | ||||
/// </summary> | /// </summary> | ||||
protected InputSpec inputSpec; | protected InputSpec inputSpec; | ||||
public InputSpec InputSpec => inputSpec; | |||||
bool dynamic = true; | bool dynamic = true; | ||||
public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
protected List<IVariableV1> _trainable_weights; | protected List<IVariableV1> _trainable_weights; | ||||
@@ -77,6 +80,8 @@ namespace Tensorflow.Keras.Engine | |||||
protected bool computePreviousMask; | protected bool computePreviousMask; | ||||
protected List<Operation> updates; | protected List<Operation> updates; | ||||
public Shape BatchInputShape => args.BatchInputShape; | public Shape BatchInputShape => args.BatchInputShape; | ||||
protected TensorShapeConfig _buildInputShape = null; | |||||
public TensorShapeConfig BuildInputShape => _buildInputShape; | |||||
List<INode> inboundNodes; | List<INode> inboundNodes; | ||||
public List<INode> InboundNodes => inboundNodes; | public List<INode> InboundNodes => inboundNodes; | ||||
@@ -86,9 +91,29 @@ namespace Tensorflow.Keras.Engine | |||||
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ||||
public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
public Tensor[] input => inboundNodes[0].input_tensors; | |||||
public Tensor[] input | |||||
{ | |||||
get | |||||
{ | |||||
if(inboundNodes is not null && inboundNodes.Count > 0) | |||||
{ | |||||
return inboundNodes[0].input_tensors; | |||||
} | |||||
return null; | |||||
} | |||||
} | |||||
public Dictionary<int, List<INode>> NodesByDepth { get; set; } | public Dictionary<int, List<INode>> NodesByDepth { get; set; } | ||||
public Shape OutputShape => inboundNodes[0].Outputs.shape; | |||||
public Shape OutputShape | |||||
{ | |||||
get | |||||
{ | |||||
if(inboundNodes is not null && inboundNodes.Count > 0) | |||||
{ | |||||
return inboundNodes[0].Outputs.shape; | |||||
} | |||||
return null; | |||||
} | |||||
} | |||||
protected List<ILayer> _self_tracked_trackables; | protected List<ILayer> _self_tracked_trackables; | ||||
public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
@@ -162,7 +187,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// </summary> | /// </summary> | ||||
/// <param name="inputs"></param> | /// <param name="inputs"></param> | ||||
/// <param name="state"></param> | /// <param name="state"></param> | ||||
/// <param name="is_training"></param> | |||||
/// <param name="training"></param> | |||||
/// <returns></returns> | /// <returns></returns> | ||||
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
{ | { | ||||
@@ -201,6 +226,7 @@ namespace Tensorflow.Keras.Engine | |||||
public virtual void build(Shape input_shape) | public virtual void build(Shape input_shape) | ||||
{ | { | ||||
_buildInputShape = input_shape; | |||||
built = true; | built = true; | ||||
} | } | ||||
@@ -286,7 +312,9 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
} | } | ||||
public virtual LayerArgs get_config() | |||||
public List<IVariableV1> Variables => weights; | |||||
public virtual IKerasConfig get_config() | |||||
=> args; | => args; | ||||
} | } | ||||
} | } |
@@ -33,6 +33,11 @@ namespace Tensorflow.Keras.Engine | |||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false) | bool use_multiprocessing = false) | ||||
{ | { | ||||
if (x.dims[0] != y.dims[0]) | |||||
{ | |||||
throw new InvalidArgumentError( | |||||
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | |||||
} | |||||
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | ||||
var train_x = x[new Slice(0, train_count)]; | var train_x = x[new Slice(0, train_count)]; | ||||
var train_y = y[new Slice(0, train_count)]; | var train_y = y[new Slice(0, train_count)]; | ||||
@@ -1,5 +1,8 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Functions; | |||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.ModelSaving; | using Tensorflow.ModelSaving; | ||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
@@ -18,9 +21,21 @@ namespace Tensorflow.Keras.Engine | |||||
bool overwrite = true, | bool overwrite = true, | ||||
bool include_optimizer = true, | bool include_optimizer = true, | ||||
string save_format = "tf", | string save_format = "tf", | ||||
SaveOptions options = null) | |||||
SaveOptions? options = null, | |||||
ConcreteFunction? signatures = null, | |||||
bool save_traces = true) | |||||
{ | { | ||||
saver.save(this, filepath); | |||||
if (save_format != "pb") | |||||
{ | |||||
saver.save(this, filepath); | |||||
} | |||||
else | |||||
{ | |||||
using (SharedObjectSavingScope.Enter()) | |||||
{ | |||||
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
} | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -4,6 +4,8 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.Keras.Saving.SavedModel; | |||||
using Tensorflow.Train; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -34,6 +36,13 @@ namespace Tensorflow.Keras.Engine | |||||
IVariableV1 _predict_counter; | IVariableV1 _predict_counter; | ||||
bool _base_model_initialized; | bool _base_model_initialized; | ||||
bool stop_training; | bool stop_training; | ||||
DataHandler data_handler; | |||||
public OptimizerV2 Optimizer | |||||
{ | |||||
get => optimizer; | |||||
set => optimizer = value; | |||||
} | |||||
public Model(ModelArgs args) | public Model(ModelArgs args) | ||||
: base(args) | : base(args) | ||||
@@ -101,5 +110,15 @@ namespace Tensorflow.Keras.Engine | |||||
return variables; | return variables; | ||||
} | } | ||||
} | } | ||||
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
{ | |||||
if(save_type == SaveType.SAVEDMODEL) | |||||
{ | |||||
//TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`. | |||||
} | |||||
var children = base._trackable_children(save_type, cache); | |||||
return children; | |||||
} | |||||
} | } | ||||
} | } |
@@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||||
{ | { | ||||
throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
} | } | ||||
_buildInputShape = input_shape; | |||||
built = true; | built = true; | ||||
} | } | ||||
@@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||||
} | } | ||||
public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
{ | { | ||||
_buildInputShape = input_shape; | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
@@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers { | |||||
// SELU has no arguments | // SELU has no arguments | ||||
} | } | ||||
public override void build(Shape input_shape) { | public override void build(Shape input_shape) { | ||||
if ( alpha < 0f ) { | |||||
throw new ValueError("Alpha must be a number greater than 0."); | |||||
} | |||||
built = true; | |||||
if ( alpha < 0f ) { | |||||
throw new ValueError("Alpha must be a number greater than 0."); | |||||
} | |||||
_buildInputShape = input_shape; | |||||
built = true; | |||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
@@ -4,6 +4,7 @@ using System.Collections; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
@@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers | |||||
return scores; | return scores; | ||||
} | } | ||||
public override LayerArgs get_config() => this.args; | |||||
public override IKerasConfig get_config() => this.args; | |||||
//var config = new Dictionary<object, object> { | //var config = new Dictionary<object, object> { | ||||
// { | // { | ||||
// "use_scale", | // "use_scale", | ||||