Browse Source

Merge pull request #1032 from AsakusaRinne/master

Fix the error of loading model saved before tf2.5.
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
3d0e2d0220
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 288 additions and 83 deletions
  1. +20
    -0
      src/TensorFlowNET.Core/APIs/tf.saved_model.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
  4. +5
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +40
    -8
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  6. +15
    -3
      src/TensorFlowNET.Core/Trackables/TrackableConstant.cs
  7. +5
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  9. +41
    -21
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  10. +45
    -22
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  11. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs
  12. +2
    -2
      src/TensorFlowNET.Core/Training/TrackableUtils.cs
  13. +18
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  14. +3
    -16
      src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs
  15. +6
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  16. +5
    -0
      src/TensorFlowNET.Keras/KerasInterface.cs
  17. +4
    -4
      src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
  18. +63
    -0
      src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs
  19. +10
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 20
- 0
src/TensorFlowNET.Core/APIs/tf.saved_model.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow
{
public partial class tensorflow
{
public SavedModelAPI saved_model { get; } = new SavedModelAPI();
}

public class SavedModelAPI
{
public Trackable load(string export_dir, LoadOptions? options = null)
{
return Loader.load(export_dir, options);
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Exceptions;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.Framework.Models; using Tensorflow.Framework.Models;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.NumPy;
using Tensorflow.Operations; using Tensorflow.Operations;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -181,7 +182,7 @@ public class FuncGraph : Graph, IDisposable
const int _EAGER_CONST_THRESHOLD = 128; const int _EAGER_CONST_THRESHOLD = 128;
public Tensor capture(Tensor tensor, string name = null, Shape shape = null) public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
{ {
if(tensor is EagerTensor)
if(tensor is EagerTensor or NDArray)
{ {
if (name == null) if (name == null)
name = ops.uid().ToString(); name = ops.uid().ToString();


+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs View File

@@ -10,4 +10,5 @@ public interface IOptimizer
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null, string name = null,
bool experimental_aggregate_gradients = true); bool experimental_aggregate_gradients = true);
IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
} }

+ 5
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -216,10 +216,12 @@ namespace Tensorflow
public virtual object get_attr(string name) public virtual object get_attr(string name)
{ {
var buf = new Buffer(); var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
tf.Status.Check(true);
Status status = new();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
var tf_buffer = c_api.TF_GetBuffer(buf);


var x = AttrValue.Parser.ParseFrom(buf.ToArray());
var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>());


var oneof_value = x.ValueCase; var oneof_value = x.ValueCase;
if (oneof_value == AttrValue.ValueOneofCase.None) if (oneof_value == AttrValue.ValueOneofCase.None)


+ 40
- 8
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -64,36 +64,68 @@ namespace Tensorflow
var num_elements = shape.size; var num_elements = shape.size;
var tensor_dtype = tensor.Dtype.as_tf_dtype(); var tensor_dtype = tensor.Dtype.as_tf_dtype();


T[] ExpandArrayToSize<T>(IList<T> src)
{
if(src.Count == 0)
{
return new T[0];
}
var pad_count = num_elements - src.Count;
var pre = pad_count / 2;
var after = pad_count - pre;
var first_elem = src[0];
var last_elem = src[src.Count - 1];
T[] res = new T[num_elements];
for(long i = 0; i < num_elements; i++)
{
if (i < pre) res[i] = first_elem;
else if (i >= num_elements - after) res[i] = last_elem;
else res[i] = src[(int)(i - pre)];
}
return res;
}

if (shape.ndim > 0 && tensor.TensorContent.Length > 0) if (shape.ndim > 0 && tensor.TensorContent.Length > 0)
{ {
return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype); return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype);
} }
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
NDArray values;
if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
{ {
return np.array(tensor.HalfVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.HalfVal));
} }
else if (tensor.Dtype == DataType.DtFloat) else if (tensor.Dtype == DataType.DtFloat)
{ {
return np.array(tensor.FloatVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.FloatVal));
} }
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
{ {
return np.array(tensor.IntVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.IntVal));
} }
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype))
{ {
return np.array(tensor.Int64Val.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.Int64Val));
} }
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype))
{ {
return np.array(tensor.Uint64Val.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.Uint64Val));
} }
else if (tensor.Dtype == DataType.DtBool) else if (tensor.Dtype == DataType.DtBool)
{ {
return np.array(tensor.BoolVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.BoolVal));
}
else
{
throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " +
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
}

if(values.size == 0)
{
return np.zeros(shape, tensor_dtype);
} }


throw new NotImplementedException("MakeNdarray");
return values.reshape(shape);
} }


private static readonly TF_DataType[] quantized_types = new TF_DataType[] private static readonly TF_DataType[] quantized_types = new TF_DataType[]


+ 15
- 3
src/TensorFlowNET.Core/Trackables/TrackableConstant.cs View File

@@ -1,5 +1,6 @@
using Google.Protobuf.Collections; using Google.Protobuf.Collections;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding;


namespace Tensorflow.Trackables; namespace Tensorflow.Trackables;


@@ -11,12 +12,23 @@ public class TrackableConstant : Trackable
_constant = constant; _constant = constant;
} }


public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
public static (Tensor, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
Dictionary<string, MapField<string, AttrValue>> operation_attributes) Dictionary<string, MapField<string, AttrValue>> operation_attributes)
{ {
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor;
var ndarray = tensor_util.MakeNdarray(tensor_proto); var ndarray = tensor_util.MakeNdarray(tensor_proto);
var imported_constant = constant_op.constant(ndarray);
return (new TrackableConstant(imported_constant), null);
Tensor imported_constant;
if (tensor_proto.Dtype == DataType.DtString)
{
imported_constant = tf_with(ops.device("CPU"), _ =>
{
return constant_op.constant(ndarray);
});
}
else
{
imported_constant = constant_op.constant(ndarray);
}
return (imported_constant, null);
} }
} }

+ 5
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs View File

@@ -46,4 +46,9 @@ public class RevivedTypes
return (null, null); return (null, null);
} }
} }

public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj)
{
_registered_revived_creator[identifier] = obj;
}
} }

+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -137,7 +137,7 @@ public class SaveableView
/// </summary> /// </summary>
public List<int> dependency_sorted_node_ids() public List<int> dependency_sorted_node_ids()
{ {
Dictionary<int, IEnumerable<int>> dependency_map = new();
Dictionary<int, List<int>> dependency_map = new();
foreach (var node in _nodes) foreach (var node in _nodes)
{ {
var node_id = _node_ids[node]; var node_id = _node_ids[node];


+ 41
- 21
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -116,17 +116,23 @@ namespace Tensorflow.Training.Saving.SavedModel
} }


Dictionary<string, ConcreteFunction> loaded_gradients = new(); Dictionary<string, ConcreteFunction> loaded_gradients = new();
foreach (var fdef in _sort_function_defs(library, function_deps))
// Debug(Rinne)
var temp = _sort_function_defs(library, function_deps);
int i = 0;
foreach (var fdef in temp)
{ {
i++;
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);


object structured_input_signature = null; object structured_input_signature = null;
object structured_outputs = null; object structured_outputs = null;
if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
{ {
var proto = saved_object_graph.ConcreteFunctions[orig_name];
structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
// TODO(Rinne): deal with structured_input_signature and structured_outputs.

//var proto = saved_object_graph.ConcreteFunctions[orig_name];
//structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
//structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
} }


graph.as_default(); graph.as_default();
@@ -234,27 +240,41 @@ namespace Tensorflow.Training.Saving.SavedModel


private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
{ {
foreach(var op in func_graph.get_operations())
if(loaded_gradients is null || loaded_gradients.Count == 0)
{ {
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
op.op._gradient_function = function._get_gradient_function();
}
string gradient_op_type = null;
try
{
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
}
catch(InvalidArgumentError)
foreach (var op in func_graph.get_operations())
{ {
continue;
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
op.op._gradient_function = function._get_gradient_function();
}
} }
if (loaded_gradients.ContainsKey(gradient_op_type))
}
else
{
foreach (var op in func_graph.get_operations())
{ {
var grad_fn = loaded_gradients[gradient_op_type];
grad_fn.NumPositionArgs = op.op.inputs.Length;
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
op.op._gradient_function = function._get_gradient_function();
}
string gradient_op_type = null;
try
{
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
}
catch (InvalidArgumentError)
{
continue;
}
if (loaded_gradients.ContainsKey(gradient_op_type))
{
var grad_fn = loaded_gradients[gradient_op_type];
grad_fn.NumPositionArgs = op.op.inputs.Length;
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
}
} }
} }
} }


+ 45
- 22
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -15,6 +15,7 @@ using Tensorflow.Functions;
using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Trackables; using Tensorflow.Trackables;
using OneOf; using OneOf;
using Tensorflow.Keras.Engine;


namespace Tensorflow namespace Tensorflow
{ {
@@ -34,7 +35,7 @@ namespace Tensorflow
private List<int>? _filtered_nodes; private List<int>? _filtered_nodes;
private List<int> _ordered_node_ids; private List<int> _ordered_node_ids;
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes;
private List<Trackable> _nodes;
private List<object> _nodes;
private Dictionary<int, Action<object, object, object>> _node_setters; private Dictionary<int, Action<object, object, object>> _node_setters;
private Dictionary<string, ConcreteFunction> _concrete_functions; private Dictionary<string, ConcreteFunction> _concrete_functions;
private HashSet<string> _restored_concrete_functions; private HashSet<string> _restored_concrete_functions;
@@ -213,7 +214,13 @@ namespace Tensorflow
continue; continue;
} }
var proto = _proto.Nodes[node_id]; var proto = _proto.Nodes[node_id];
foreach(var dep in _get_node_dependencies(proto).Values.Distinct())
if(node_id == 10522)
{
// Debug(Rinne)
Console.WriteLine();
}
var temp = _get_node_dependencies(proto);
foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
{ {
deps.Add(dep); deps.Add(dep);
if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
@@ -232,7 +239,7 @@ namespace Tensorflow
// The optimizer and original variable must be created before the slot // The optimizer and original variable must be created before the slot
// variable, since the slot variable is generated using the Optimizer's // variable, since the slot variable is generated using the Optimizer's
// add_slot API. // add_slot API.
var slot_deps = dependency_map[slot_variable_node_id];
var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>());
slot_deps.Add(node_id); slot_deps.Add(node_id);
slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);


@@ -245,7 +252,12 @@ namespace Tensorflow
} }
try try
{ {
return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
int total = 0;
foreach(var v in dependency_map.Values)
{
total += v.Count;
}
return TrackableUtils.order_by_dependency(dependency_map);
} }
catch (TrackableUtils.CyclicDependencyError ex) catch (TrackableUtils.CyclicDependencyError ex)
{ {
@@ -339,9 +351,20 @@ namespace Tensorflow
var saveable_object_proto = item.Value; var saveable_object_proto = item.Value;
var save_fn_id = saveable_object_proto.SaveFunction; var save_fn_id = saveable_object_proto.SaveFunction;
var restore_fn_id = saveable_object_proto.RestoreFunction; var restore_fn_id = saveable_object_proto.RestoreFunction;
saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id));
saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id));
}
var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
if (saveable_objects is not null && saveable_objects.Count > 0)
{
if(node is Trackable trackable)
{
trackable.SelfSaveableObjectFactories = saveable_objects;
}
else
{
throw new TypeError();
}
} }
node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
} }
} }
} }
@@ -379,12 +402,12 @@ namespace Tensorflow
{ {
// Use the public Optimizer interface when creating slot variables. // Use the public Optimizer interface when creating slot variables.
var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
var optimizer_object = nodes[optimizer_node_id];
var optimizer_object = nodes[optimizer_node_id] as IOptimizer;
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];


// TODO(Rinne): implement it.
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName);
nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable;
node_setters[slot_variable_proto.SlotVariableNodeId] = setattr;
} }
else else
{ {
@@ -398,7 +421,7 @@ namespace Tensorflow
{ {
nodes[0] = _recreate_base_user_object().Item1; nodes[0] = _recreate_base_user_object().Item1;
} }
_nodes = new List<Trackable>();
_nodes = new List<object>();
for(int i = 0; i < _proto.Nodes.Count; i++) for(int i = 0; i < _proto.Nodes.Count; i++)
{ {
_nodes.Add(nodes[i]); _nodes.Add(nodes[i]);
@@ -412,7 +435,7 @@ namespace Tensorflow
private void _restore_checkpoint() private void _restore_checkpoint()
{ {
var variables_path = SavedModelUtils.get_variables_path(_export_dir); var variables_path = SavedModelUtils.get_variables_path(_export_dir);
var saver = new TrackableSaver(new ObjectGraphView(get(0)));
var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0)));
tf_with(ops.device("CPU"), _ => tf_with(ops.device("CPU"), _ =>
{ {
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
@@ -467,7 +490,7 @@ namespace Tensorflow
} }
} }


private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes)
private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, object> nodes)
{ {
if (_restored_concrete_functions.Contains(concrete_function_name)) if (_restored_concrete_functions.Contains(concrete_function_name))
{ {
@@ -485,12 +508,12 @@ namespace Tensorflow
// TODO: implement it with concrete functions. // TODO: implement it with concrete functions.
} }


public Trackable get(int node_id)
public object get(int node_id)
{ {
return _nodes[node_id]; return _nodes[node_id];
} }


public Trackable get(string node_id)
public object get(string node_id)
{ {
return get(_node_path_to_id[node_id]); return get(_node_path_to_id[node_id]);
} }
@@ -512,9 +535,9 @@ namespace Tensorflow
} }
} }


private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
private (Dictionary<int, object>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{ {
Dictionary<int, Trackable> nodes = new();
Dictionary<int, object> nodes = new();
Dictionary<int, Action<object, object, object>> node_setters = new(); Dictionary<int, Action<object, object, object>> node_setters = new();
foreach(var item in _loaded_nodes) foreach(var item in _loaded_nodes)
{ {
@@ -534,10 +557,10 @@ namespace Tensorflow
} }
} }


private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
private (object, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, object> nodes)
{ {
// skip the registered classes. // skip the registered classes.
Dictionary<OneOf<string, int>, Trackable> dependencies = new();
Dictionary<OneOf<string, int>, object> dependencies = new();
foreach(var item in _get_node_dependencies(proto)) foreach(var item in _get_node_dependencies(proto))
{ {
dependencies[item.Key] = nodes[item.Value]; dependencies[item.Key] = nodes[item.Value];
@@ -558,7 +581,7 @@ namespace Tensorflow
/// <param name="proto"></param> /// <param name="proto"></param>
/// <param name="node_id"></param> /// <param name="node_id"></param>
/// <param name="dependencies"></param> /// <param name="dependencies"></param>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies)
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, object> dependencies)
{ {
return proto.KindCase switch return proto.KindCase switch
{ {
@@ -626,7 +649,7 @@ namespace Tensorflow
} }


private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto, private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto,
IDictionary<OneOf<string, int>, Trackable> dependencies)
IDictionary<OneOf<string, int>, object> dependencies)
{ {
var fn = function_deserialization.recreate_function(proto, _concrete_functions); var fn = function_deserialization.recreate_function(proto, _concrete_functions);
foreach (var name in proto.ConcreteFunctions) foreach (var name in proto.ConcreteFunctions)
@@ -637,7 +660,7 @@ namespace Tensorflow
} }


private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
IDictionary<OneOf<string, int>, Trackable> dependencies)
IDictionary<OneOf<string, int>, object> dependencies)
{ {
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions);
_setup_function_captures(proto.ConcreteFunctionName, dependencies); _setup_function_captures(proto.ConcreteFunctionName, dependencies);


+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow
tf_with(ops.init_scope(), x => tf_with(ops.init_scope(), x =>
{ {
loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
root = loader.get(0);
root = (Trackable)loader.get(0);
// skip the assignment of `graph_debug_info`. // skip the assignment of `graph_debug_info`.
}); });
// skip the assignment of `tensorflow_version` // skip the assignment of `tensorflow_version`
@@ -99,7 +99,7 @@ namespace Tensorflow
} }
if(filters != null && filters.Count > 0) if(filters != null && filters.Count > 0)
{ {
return filters.Keys.ToDictionary(x => x, x => loader.get(x));
return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x));
} }
else else
{ {


+ 2
- 2
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -52,7 +52,7 @@ public static class TrackableUtils
/// </summary> /// </summary>
/// <param name="dependency_map"></param> /// <param name="dependency_map"></param>
/// <exception cref="ValueError"></exception> /// <exception cref="ValueError"></exception>
public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map)
public static List<int> order_by_dependency(IDictionary<int, List<int>> dependency_map)
{ {
Dictionary<int, HashSet<int>> reverse_dependency_map = new(); Dictionary<int, HashSet<int>> reverse_dependency_map = new();
foreach (var pair in dependency_map) foreach (var pair in dependency_map)
@@ -102,7 +102,7 @@ public static class TrackableUtils
edges.Remove(x); edges.Remove(x);
if (edges.Count == 0) if (edges.Count == 0)
{ {
to_visit.Enqueue(dep);
to_visit.Enqueue(dep);
if (!reverse_dependency_map.Remove(dep)) if (!reverse_dependency_map.Remove(dep))
{ {
throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map");


+ 18
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -333,5 +333,23 @@ namespace Tensorflow
}); });
return array_ops.identity(value); return array_ops.identity(value);
} }

//public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y;
//public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y;
//public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y;
//public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value();
//public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value();

//public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value();
//public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y;
//public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y;

//public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y;

//public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y;
} }
} }

+ 3
- 16
src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs View File

@@ -1,19 +1,6 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy; using Tensorflow.NumPy;


namespace Tensorflow namespace Tensorflow


+ 6
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -169,6 +169,12 @@ namespace Tensorflow.Keras
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
} }


public void set_value(IVariableV1 x, object value)
{
// TODO(Rinne): check the implementation.
x.assign(value);
}

public void batch_set_value(List<(IVariableV1, NDArray)> tuples) public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
{ {
if (ops.executing_eagerly_outside_functions()) if (ops.executing_eagerly_outside_functions())


+ 5
- 0
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -36,6 +36,11 @@ namespace Tensorflow.Keras
} }
} }


static KerasInterface()
{
RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer());
}

public KerasDataset datasets { get; } = new KerasDataset(); public KerasDataset datasets { get; } = new KerasDataset();
public IInitializersApi initializers { get; } = new InitializersApi(); public IInitializersApi initializers { get; } = new InitializersApi();
public Regularizers regularizers { get; } = new Regularizers(); public Regularizers regularizers { get; } = new Regularizers();


+ 4
- 4
src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs View File

@@ -14,11 +14,11 @@ namespace Tensorflow.Keras.Optimizers
protected bool _hypers_created; protected bool _hypers_created;
protected virtual string _name { get; } protected virtual string _name { get; }


IVariableV1 _iterations;
protected IVariableV1 _iterations;
protected ResourceVariable iterations => _iterations as ResourceVariable; protected ResourceVariable iterations => _iterations as ResourceVariable;
List<IVariableV1> _weights; List<IVariableV1> _weights;
Dictionary<string, float> _hyper;
Dictionary<string, IVariableV1> _hyper_variables;
protected Dictionary<string, float> _hyper;
protected Dictionary<string, IVariableV1> _hyper_variables;
protected bool _momentum; protected bool _momentum;
protected float _initial_decay = 0.0f; protected float _initial_decay = 0.0f;
protected bool _use_locking = true; protected bool _use_locking = true;
@@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Optimizers
} }
} }


protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
{ {
if (initializer == null) if (initializer == null)
initializer = tf.zeros_initializer; initializer = tf.zeros_initializer;


+ 63
- 0
src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs View File

@@ -0,0 +1,63 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Saving;
using Tensorflow.Train;
using Tensorflow.Training;

namespace Tensorflow.Keras.Optimizers
{
public class RestoredOptimizer: OptimizerV2, ITrackableWrapper, IKerasConfig
{
public String Identifier { get; } = "optimizer";
public int Version { get; } = 2;
public int MinConsumerVersion { get; } = 1;
public int MinProducerVersion { get; } = 1;
public RestoredOptimizer(): base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" })
{
_hypers_created = true;
}

public IKerasConfig get_config()
{
throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " +
"supported. Please file a feature request if this limitation bothers you.");
}

public void SetValue(object name, object value)
{
if(name is not String str)
{
throw new TypeError($"The name of value to set must be string, but got {name.GetType()}");
}
if(value is Trackable trackable)
{
_track_trackable(trackable, str, overwrite: true);
}
if(value is IVariableV1 resource_variable)
{
if (!_hyper_variables.ContainsKey(str))
{
_hyper_variables[str] = resource_variable;
}
else
{
keras.backend.set_value(resource_variable, value);
}
}
else if (value is float f)
{
_hyper[str] = f;
}
else
{
throw new NotImplementedException();
}
}
public Trackable FromProto(SavedUserObject proto)
{
return new RestoredOptimizer();
}
}
}

+ 10
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -2,6 +2,7 @@
using System; using System;
using System.Linq; using System.Linq;
using Tensorflow; using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy; using Tensorflow.NumPy;
@@ -103,4 +104,13 @@ public class SequentialModelLoad


classify_model.fit(x, y, batch_size: 4); classify_model.fit(x, y, batch_size: 4);
} }

[Ignore]
[TestMethod]
public void TestModelBeforeTF2_5()
{
var a = keras.layers;
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model;
model.summary();
}
} }

Loading…
Cancel
Save