Fix the error of loading models saved before TF 2.5 (tag: v0.100.5-BERT-load)
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Train;
+
+namespace Tensorflow
+{
+    public partial class tensorflow
+    {
+        public SavedModelAPI saved_model { get; } = new SavedModelAPI();
+    }
+
+    public class SavedModelAPI
+    {
+        public Trackable load(string export_dir, LoadOptions? options = null)
+        {
+            return Loader.load(export_dir, options);
+        }
+    }
+}
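The new `tf.saved_model` property exposes a `load` entry point analogous to Python's `tf.saved_model.load`. A minimal usage sketch (the SavedModel directory path is a placeholder, not part of this PR):

```csharp
using System;
using Tensorflow;
using Tensorflow.Train;
using static Tensorflow.Binding;

class LoadSavedModelDemo
{
    static void Main()
    {
        // Placeholder path; point it at a directory produced by TensorFlow's SavedModel export.
        Trackable root = tf.saved_model.load(@"path/to/saved_model");

        // The loader returns the root Trackable of the restored object graph.
        Console.WriteLine(root.GetType().Name);
    }
}
```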
@@ -8,6 +8,7 @@ using Tensorflow.Exceptions;
 using Tensorflow.Framework;
 using Tensorflow.Framework.Models;
 using Tensorflow.Functions;
+using Tensorflow.NumPy;
 using Tensorflow.Operations;
 using Tensorflow.Util;
 using static Tensorflow.Binding;
@@ -181,7 +182,7 @@ public class FuncGraph : Graph, IDisposable
     const int _EAGER_CONST_THRESHOLD = 128;
     public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
     {
-        if (tensor is EagerTensor)
+        if (tensor is EagerTensor or NDArray)
         {
             if (name == null)
                 name = ops.uid().ToString();
@@ -10,4 +10,5 @@ public interface IOptimizer
     void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
                          string name = null,
                          bool experimental_aggregate_gradients = true);
+    IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
 }
@@ -216,10 +216,12 @@ namespace Tensorflow
     public virtual object get_attr(string name)
     {
         var buf = new Buffer();
-        c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
-        tf.Status.Check(true);
+        Status status = new();
+        c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
+        status.Check(true);
+        var tf_buffer = c_api.TF_GetBuffer(buf);

-        var x = AttrValue.Parser.ParseFrom(buf.ToArray());
+        var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>());
         var oneof_value = x.ValueCase;

         if (oneof_value == AttrValue.ValueOneofCase.None)
@@ -64,36 +64,68 @@ namespace Tensorflow
     var num_elements = shape.size;
     var tensor_dtype = tensor.Dtype.as_tf_dtype();

+    T[] ExpandArrayToSize<T>(IList<T> src)
+    {
+        if (src.Count == 0)
+        {
+            return new T[0];
+        }
+        var pad_count = num_elements - src.Count;
+        var pre = pad_count / 2;
+        var after = pad_count - pre;
+        var first_elem = src[0];
+        var last_elem = src[src.Count - 1];
+        T[] res = new T[num_elements];
+        for (long i = 0; i < num_elements; i++)
+        {
+            if (i < pre) res[i] = first_elem;
+            else if (i >= num_elements - after) res[i] = last_elem;
+            else res[i] = src[(int)(i - pre)];
+        }
+        return res;
+    }
+
     if (shape.ndim > 0 && tensor.TensorContent.Length > 0)
     {
         return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype);
     }
-    else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
+
+    NDArray values;
+    if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
     {
-        return np.array(tensor.HalfVal.ToArray()).reshape(shape);
+        values = np.array(ExpandArrayToSize(tensor.HalfVal));
     }
     else if (tensor.Dtype == DataType.DtFloat)
     {
-        return np.array(tensor.FloatVal.ToArray()).reshape(shape);
+        values = np.array(ExpandArrayToSize(tensor.FloatVal));
     }
     else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
     {
-        return np.array(tensor.IntVal.ToArray()).reshape(shape);
+        values = np.array(ExpandArrayToSize(tensor.IntVal));
     }
     else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype))
     {
-        return np.array(tensor.Int64Val.ToArray()).reshape(shape);
+        values = np.array(ExpandArrayToSize(tensor.Int64Val));
     }
     else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype))
     {
-        return np.array(tensor.Uint64Val.ToArray()).reshape(shape);
+        values = np.array(ExpandArrayToSize(tensor.Uint64Val));
     }
     else if (tensor.Dtype == DataType.DtBool)
     {
-        return np.array(tensor.BoolVal.ToArray()).reshape(shape);
+        values = np.array(ExpandArrayToSize(tensor.BoolVal));
+    }
+    else
+    {
+        throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " +
+            $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
+    }
+
+    if (values.size == 0)
+    {
+        return np.zeros(shape, tensor_dtype);
     }
-    throw new NotImplementedException("MakeNdarray");
+
+    return values.reshape(shape);
 }

 private static readonly TF_DataType[] quantized_types = new TF_DataType[]
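To illustrate why `ExpandArrayToSize` exists: a `TensorProto` for a constant may store fewer scalar values than its shape implies (typically a single splatted value), and the helper pads the stored run out to `num_elements` by repeating the first value in front and the last value behind before the final reshape. A standalone sketch of the same rule (illustration only, not the library code):

```csharp
using System;
using System.Collections.Generic;

class ExpandArrayDemo
{
    // Restates the padding rule of ExpandArrayToSize for illustration.
    static T[] Expand<T>(IList<T> src, long numElements)
    {
        if (src.Count == 0) return new T[0];
        var pre = (numElements - src.Count) / 2;
        var res = new T[numElements];
        for (long i = 0; i < numElements; i++)
            res[i] = i < pre ? src[0]                          // left padding
                   : i >= pre + src.Count ? src[src.Count - 1] // right padding
                   : src[(int)(i - pre)];                      // stored values
        return res;
    }

    static void Main()
    {
        // A float proto for shape (5,) storing the single value 3 expands to [3, 3, 3, 3, 3].
        Console.WriteLine(string.Join(", ", Expand(new[] { 3f }, 5)));
    }
}
```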
@@ -1,5 +1,6 @@
 using Google.Protobuf.Collections;
 using Tensorflow.Train;
+using static Tensorflow.Binding;

 namespace Tensorflow.Trackables;

@@ -11,12 +12,23 @@ public class TrackableConstant : Trackable
         _constant = constant;
     }

-    public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
+    public static (Tensor, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
         Dictionary<string, MapField<string, AttrValue>> operation_attributes)
     {
         var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor;
         var ndarray = tensor_util.MakeNdarray(tensor_proto);
-        var imported_constant = constant_op.constant(ndarray);
-        return (new TrackableConstant(imported_constant), null);
+        Tensor imported_constant;
+        if (tensor_proto.Dtype == DataType.DtString)
+        {
+            imported_constant = tf_with(ops.device("CPU"), _ =>
+            {
+                return constant_op.constant(ndarray);
+            });
+        }
+        else
+        {
+            imported_constant = constant_op.constant(ndarray);
+        }
+        return (imported_constant, null);
     }
 }
@@ -46,4 +46,9 @@ public class RevivedTypes
             return (null, null);
         }
     }
+
+    public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj)
+    {
+        _registered_revived_creator[identifier] = obj;
+    }
 }
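For reference, the only consumer of this new hook within this PR is the `KerasInterface` static constructor further down, which registers the new `RestoredOptimizer` wrapper under the "optimizer" identifier so that saved Keras optimizer nodes are revived through it:

```csharp
// As wired up later in this PR (KerasInterface's static constructor):
RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer());
```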
@@ -137,7 +137,7 @@ public class SaveableView
     /// </summary>
     public List<int> dependency_sorted_node_ids()
     {
-        Dictionary<int, IEnumerable<int>> dependency_map = new();
+        Dictionary<int, List<int>> dependency_map = new();
         foreach (var node in _nodes)
         {
             var node_id = _node_ids[node];
@@ -116,17 +116,23 @@ namespace Tensorflow.Training.Saving.SavedModel
         }

         Dictionary<string, ConcreteFunction> loaded_gradients = new();
-        foreach (var fdef in _sort_function_defs(library, function_deps))
+        // Debug(Rinne)
+        var temp = _sort_function_defs(library, function_deps);
+        int i = 0;
+        foreach (var fdef in temp)
         {
+            i++;
             var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);
             object structured_input_signature = null;
             object structured_outputs = null;
             if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
             {
-                var proto = saved_object_graph.ConcreteFunctions[orig_name];
-                structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
-                structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
+                // TODO(Rinne): deal with structured_input_signature and structured_outputs.
+                //var proto = saved_object_graph.ConcreteFunctions[orig_name];
+                //structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
+                //structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
             }

             graph.as_default();
@@ -234,27 +240,41 @@ namespace Tensorflow.Training.Saving.SavedModel
     private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
     {
-        foreach (var op in func_graph.get_operations())
-        {
-            if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
-            {
-                var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
-                op.op._gradient_function = function._get_gradient_function();
-            }
-            string gradient_op_type = null;
-            try
-            {
-                gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
-            }
-            catch (InvalidArgumentError)
-            {
-                continue;
-            }
-            if (loaded_gradients.ContainsKey(gradient_op_type))
-            {
-                var grad_fn = loaded_gradients[gradient_op_type];
-                grad_fn.NumPositionArgs = op.op.inputs.Length;
-                grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
-            }
-        }
+        if (loaded_gradients is null || loaded_gradients.Count == 0)
+        {
+            foreach (var op in func_graph.get_operations())
+            {
+                if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
+                {
+                    var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
+                    op.op._gradient_function = function._get_gradient_function();
+                }
+            }
+        }
+        else
+        {
+            foreach (var op in func_graph.get_operations())
+            {
+                if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
+                {
+                    var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
+                    op.op._gradient_function = function._get_gradient_function();
+                }
+                string gradient_op_type = null;
+                try
+                {
+                    gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
+                }
+                catch (InvalidArgumentError)
+                {
+                    continue;
+                }
+                if (loaded_gradients.ContainsKey(gradient_op_type))
+                {
+                    var grad_fn = loaded_gradients[gradient_op_type];
+                    grad_fn.NumPositionArgs = op.op.inputs.Length;
+                    grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
+                }
+            }
+        }
     }
@@ -15,6 +15,7 @@ using Tensorflow.Functions;
 using Tensorflow.Training.Saving.SavedModel;
 using Tensorflow.Trackables;
 using OneOf;
+using Tensorflow.Keras.Engine;

 namespace Tensorflow
 {
@@ -34,7 +35,7 @@ namespace Tensorflow
         private List<int>? _filtered_nodes;
         private List<int> _ordered_node_ids;
         private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes;
-        private List<Trackable> _nodes;
+        private List<object> _nodes;
         private Dictionary<int, Action<object, object, object>> _node_setters;
         private Dictionary<string, ConcreteFunction> _concrete_functions;
         private HashSet<string> _restored_concrete_functions;
@@ -213,7 +214,13 @@ namespace Tensorflow
                     continue;
                 }
                 var proto = _proto.Nodes[node_id];
-                foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
+                if (node_id == 10522)
+                {
+                    // Debug(Rinne)
+                    Console.WriteLine();
+                }
+                var temp = _get_node_dependencies(proto);
+                foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
                 {
                     deps.Add(dep);
                     if (_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
@@ -232,7 +239,7 @@ namespace Tensorflow
                 // The optimizer and original variable must be created before the slot
                 // variable, since the slot variable is generated using the Optimizer's
                 // add_slot API.
-                var slot_deps = dependency_map[slot_variable_node_id];
+                var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>());
                 slot_deps.Add(node_id);
                 slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);
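`SetDefault` is used here like Python's `dict.setdefault`: fetch the existing dependency list or insert an empty one for the slot variable. A minimal sketch of the assumed extension-method semantics (not the actual `Tensorflow.Util` implementation):

```csharp
using System.Collections.Generic;

static class DictionaryExtensionsSketch
{
    // Assumed behavior: return the value already stored under `key`,
    // otherwise store `defaultValue` and return it.
    public static TValue SetDefault<TKey, TValue>(this IDictionary<TKey, TValue> dict, TKey key, TValue defaultValue)
    {
        if (dict.TryGetValue(key, out var existing))
            return existing;
        dict[key] = defaultValue;
        return defaultValue;
    }
}
```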
@@ -245,7 +252,12 @@ namespace Tensorflow
             }
             try
             {
-                return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
+                int total = 0;
+                foreach (var v in dependency_map.Values)
+                {
+                    total += v.Count;
+                }
+                return TrackableUtils.order_by_dependency(dependency_map);
             }
             catch (TrackableUtils.CyclicDependencyError ex)
             {
@@ -339,9 +351,20 @@ namespace Tensorflow
                     var saveable_object_proto = item.Value;
                     var save_fn_id = saveable_object_proto.SaveFunction;
                     var restore_fn_id = saveable_object_proto.RestoreFunction;
-                    saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id));
+                    saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id));
+                }
+
+                var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
+                if (saveable_objects is not null && saveable_objects.Count > 0)
+                {
+                    if (node is Trackable trackable)
+                    {
+                        trackable.SelfSaveableObjectFactories = saveable_objects;
+                    }
+                    else
+                    {
+                        throw new TypeError();
+                    }
                 }
-                node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
             }
         }
     }
@@ -379,12 +402,12 @@ namespace Tensorflow
             {
                 // Use the public Optimizer interface when creating slot variables.
                 var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
-                var optimizer_object = nodes[optimizer_node_id];
+                var optimizer_object = nodes[optimizer_node_id] as IOptimizer;
                 var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];

-                // TODO(Rinne): implement it.
-                throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
-                    " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
+                var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName);
+                nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable;
+                node_setters[slot_variable_proto.SlotVariableNodeId] = setattr;
             }
             else
             {
@@ -398,7 +421,7 @@ namespace Tensorflow
             {
                 nodes[0] = _recreate_base_user_object().Item1;
             }
-            _nodes = new List<Trackable>();
+            _nodes = new List<object>();
             for (int i = 0; i < _proto.Nodes.Count; i++)
             {
                 _nodes.Add(nodes[i]);
@@ -412,7 +435,7 @@ namespace Tensorflow
         private void _restore_checkpoint()
         {
             var variables_path = SavedModelUtils.get_variables_path(_export_dir);
-            var saver = new TrackableSaver(new ObjectGraphView(get(0)));
+            var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0)));
             tf_with(ops.device("CPU"), _ =>
             {
                 saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
@@ -467,7 +490,7 @@ namespace Tensorflow
             }
         }

-        private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes)
+        private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, object> nodes)
         {
             if (_restored_concrete_functions.Contains(concrete_function_name))
             {
@@ -485,12 +508,12 @@ namespace Tensorflow
             // TODO: implement it with concrete functions.
         }

-        public Trackable get(int node_id)
+        public object get(int node_id)
         {
             return _nodes[node_id];
         }

-        public Trackable get(string node_id)
+        public object get(string node_id)
         {
             return get(_node_path_to_id[node_id]);
         }
@@ -512,9 +535,9 @@ namespace Tensorflow
             }
         }

-        private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
+        private (Dictionary<int, object>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
         {
-            Dictionary<int, Trackable> nodes = new();
+            Dictionary<int, object> nodes = new();
             Dictionary<int, Action<object, object, object>> node_setters = new();
             foreach (var item in _loaded_nodes)
             {
@@ -534,10 +557,10 @@ namespace Tensorflow
             }
         }

-        private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
+        private (object, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, object> nodes)
         {
             // skip the registered classes.
-            Dictionary<OneOf<string, int>, Trackable> dependencies = new();
+            Dictionary<OneOf<string, int>, object> dependencies = new();
             foreach (var item in _get_node_dependencies(proto))
             {
                 dependencies[item.Key] = nodes[item.Value];
@@ -558,7 +581,7 @@ namespace Tensorflow
         /// <param name="proto"></param>
         /// <param name="node_id"></param>
         /// <param name="dependencies"></param>
-        private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies)
+        private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, object> dependencies)
         {
             return proto.KindCase switch
             {
@@ -626,7 +649,7 @@ namespace Tensorflow
         }

         private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto,
-            IDictionary<OneOf<string, int>, Trackable> dependencies)
+            IDictionary<OneOf<string, int>, object> dependencies)
         {
             var fn = function_deserialization.recreate_function(proto, _concrete_functions);
             foreach (var name in proto.ConcreteFunctions)
@@ -637,7 +660,7 @@ namespace Tensorflow
         }

         private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
-            IDictionary<OneOf<string, int>, Trackable> dependencies)
+            IDictionary<OneOf<string, int>, object> dependencies)
         {
             var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions);
             _setup_function_captures(proto.ConcreteFunctionName, dependencies);
@@ -78,7 +78,7 @@ namespace Tensorflow
             tf_with(ops.init_scope(), x =>
             {
                 loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
-                root = loader.get(0);
+                root = (Trackable)loader.get(0);
                 // skip the assignment of `graph_debug_info`.
             });
             // skip the assignment of `tensorflow_version`
@@ -99,7 +99,7 @@ namespace Tensorflow
             }
             if (filters != null && filters.Count > 0)
             {
-                return filters.Keys.ToDictionary(x => x, x => loader.get(x));
+                return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x));
             }
             else
             {
@@ -52,7 +52,7 @@ public static class TrackableUtils
     /// </summary>
     /// <param name="dependency_map"></param>
     /// <exception cref="ValueError"></exception>
-    public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map)
+    public static List<int> order_by_dependency(IDictionary<int, List<int>> dependency_map)
     {
         Dictionary<int, HashSet<int>> reverse_dependency_map = new();
         foreach (var pair in dependency_map)
@@ -102,7 +102,7 @@ public static class TrackableUtils
                     edges.Remove(x);
                     if (edges.Count == 0)
                     {
-                        to_visit.Enqueue(dep);
+                        to_visit.Enqueue(dep);
                         if (!reverse_dependency_map.Remove(dep))
                         {
                             throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map");
@@ -333,5 +333,23 @@ namespace Tensorflow
             });
             return array_ops.identity(value);
         }
+
+        //public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y;
+        //public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y;
+        //public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y;
+        //public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value();
+        //public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y;
+        //public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y;
+        //public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y;
+        //public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y;
+        //public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value();
+        //public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value();
+        //public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y;
+        //public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y;
+        //public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y;
+        //public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y;
     }
 }
@@ -1,19 +1,6 @@
-/*****************************************************************************
-   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
-******************************************************************************/
-using System;
-using System.Collections.Generic;
-using System.Text;
 using Tensorflow.NumPy;

 namespace Tensorflow
@@ -169,6 +169,12 @@ namespace Tensorflow.Keras
             _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
         }

+        public void set_value(IVariableV1 x, object value)
+        {
+            // TODO(Rinne): check the implementation.
+            x.assign(value);
+        }
+
         public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
         {
             if (ops.executing_eagerly_outside_functions())
@@ -36,6 +36,11 @@ namespace Tensorflow.Keras
             }
         }

+        static KerasInterface()
+        {
+            RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer());
+        }
+
         public KerasDataset datasets { get; } = new KerasDataset();
         public IInitializersApi initializers { get; } = new InitializersApi();
         public Regularizers regularizers { get; } = new Regularizers();
@@ -14,11 +14,11 @@ namespace Tensorflow.Keras.Optimizers
         protected bool _hypers_created;
         protected virtual string _name { get; }

-        IVariableV1 _iterations;
+        protected IVariableV1 _iterations;
         protected ResourceVariable iterations => _iterations as ResourceVariable;
         List<IVariableV1> _weights;
-        Dictionary<string, float> _hyper;
-        Dictionary<string, IVariableV1> _hyper_variables;
+        protected Dictionary<string, float> _hyper;
+        protected Dictionary<string, IVariableV1> _hyper_variables;
         protected bool _momentum;
         protected float _initial_decay = 0.0f;
         protected bool _use_locking = true;
@@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Optimizers
             }
         }

-        protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
+        public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
         {
             if (initializer == null)
                 initializer = tf.zeros_initializer;
@@ -0,0 +1,63 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Saving;
+using Tensorflow.Train;
+using Tensorflow.Training;
+
+namespace Tensorflow.Keras.Optimizers
+{
+    public class RestoredOptimizer : OptimizerV2, ITrackableWrapper, IKerasConfig
+    {
+        public String Identifier { get; } = "optimizer";
+        public int Version { get; } = 2;
+        public int MinConsumerVersion { get; } = 1;
+        public int MinProducerVersion { get; } = 1;
+
+        public RestoredOptimizer() : base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" })
+        {
+            _hypers_created = true;
+        }
+
+        public IKerasConfig get_config()
+        {
+            throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " +
+                "supported. Please file a feature request if this limitation bothers you.");
+        }
+
+        public void SetValue(object name, object value)
+        {
+            if (name is not String str)
+            {
+                throw new TypeError($"The name of value to set must be string, but got {name.GetType()}");
+            }
+            if (value is Trackable trackable)
+            {
+                _track_trackable(trackable, str, overwrite: true);
+            }
+            if (value is IVariableV1 resource_variable)
+            {
+                if (!_hyper_variables.ContainsKey(str))
+                {
+                    _hyper_variables[str] = resource_variable;
+                }
+                else
+                {
+                    keras.backend.set_value(resource_variable, value);
+                }
+            }
+            else if (value is float f)
+            {
+                _hyper[str] = f;
+            }
+            else
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        public Trackable FromProto(SavedUserObject proto)
+        {
+            return new RestoredOptimizer();
+        }
+    }
+}
@@ -2,6 +2,7 @@
 using System;
 using System.Linq;
 using Tensorflow;
+using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Optimizers;
 using Tensorflow.Keras.UnitTest.Helpers;
 using Tensorflow.NumPy;
@@ -103,4 +104,13 @@ public class SequentialModelLoad
         classify_model.fit(x, y, batch_size: 4);
     }
+
+    [Ignore]
+    [TestMethod]
+    public void TestModelBeforeTF2_5()
+    {
+        var a = keras.layers;
+        var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model;
+        model.summary();
+    }
 }
} | } |