|
@@ -13,6 +13,7 @@ using System.Runtime.CompilerServices; |
|
|
using Tensorflow.Variables; |
|
|
using Tensorflow.Variables; |
|
|
using Tensorflow.Functions; |
|
|
using Tensorflow.Functions; |
|
|
using Tensorflow.Training.Saving.SavedModel; |
|
|
using Tensorflow.Training.Saving.SavedModel; |
|
|
|
|
|
using Tensorflow.Trackables; |
|
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
namespace Tensorflow |
|
|
{ |
|
|
{ |
|
@@ -51,9 +52,13 @@ namespace Tensorflow |
|
|
_node_filters = filters; |
|
|
_node_filters = filters; |
|
|
_node_path_to_id = _convert_node_paths_to_ints(); |
|
|
_node_path_to_id = _convert_node_paths_to_ints(); |
|
|
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); |
|
|
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); |
|
|
foreach(var filter in filters) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (filters != null) |
|
|
{ |
|
|
{ |
|
|
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; |
|
|
|
|
|
|
|
|
foreach (var filter in filters) |
|
|
|
|
|
{ |
|
|
|
|
|
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
_filtered_nodes = _retrieve_all_filtered_nodes(); |
|
|
_filtered_nodes = _retrieve_all_filtered_nodes(); |
|
@@ -535,7 +540,13 @@ namespace Tensorflow |
|
|
dependencies[item.Key] = nodes[item.Value]; |
|
|
dependencies[item.Key] = nodes[item.Value]; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return _recreate_default(proto, node_id, dependencies); |
|
|
|
|
|
|
|
|
return proto.KindCase switch |
|
|
|
|
|
{ |
|
|
|
|
|
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(), |
|
|
|
|
|
SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(), |
|
|
|
|
|
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(), |
|
|
|
|
|
_ => _recreate_default(proto, node_id, dependencies) |
|
|
|
|
|
}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
/// <summary> |
|
@@ -549,7 +560,7 @@ namespace Tensorflow |
|
|
return proto.KindCase switch |
|
|
return proto.KindCase switch |
|
|
{ |
|
|
{ |
|
|
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), |
|
|
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), |
|
|
SavedObject.KindOneofCase.Function => throw new NotImplementedException(), |
|
|
|
|
|
|
|
|
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), |
|
|
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), |
|
|
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), |
|
|
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), |
|
|
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), |
|
|
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() |
|
|
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() |
|
@@ -609,6 +620,13 @@ namespace Tensorflow |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto, |
|
|
|
|
|
Dictionary<Maybe<string, int>, Trackable> dependencies) |
|
|
|
|
|
{ |
|
|
|
|
|
throw new NotImplementedException(); |
|
|
|
|
|
//var fn = function_deserialization.setup_bare_concrete_function(proto, ) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, |
|
|
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, |
|
|
Dictionary<Maybe<string, int>, Trackable> dependencies) |
|
|
Dictionary<Maybe<string, int>, Trackable> dependencies) |
|
|
{ |
|
|
{ |
|
|