Fix the error when loading a model saved before TF 2.5. (tags/v0.100.5-BERT-load)
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Train; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow
{
    /// <summary>
    /// Entry point for the SavedModel helpers exposed by <see cref="SavedModelAPI"/>.
    /// A single instance is created together with the owning object.
    /// </summary>
    public SavedModelAPI saved_model { get; } = new();
}
/// <summary>
/// Thin facade over <c>Loader</c> for working with SavedModel directories.
/// </summary>
public class SavedModelAPI
{
    /// <summary>
    /// Loads the SavedModel stored at <paramref name="export_dir"/> by forwarding
    /// the call to <c>Loader.load</c>.
    /// </summary>
    /// <param name="export_dir">Directory that contains the SavedModel.</param>
    /// <param name="options">Optional load options; passed through unchanged (may be null).</param>
    /// <returns>The object returned by <c>Loader.load</c>.</returns>
    public Trackable load(string export_dir, LoadOptions? options = null)
        => Loader.load(export_dir, options);
}
} |
@@ -8,6 +8,7 @@ using Tensorflow.Exceptions; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Functions; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -181,7 +182,7 @@ public class FuncGraph : Graph, IDisposable | |||
const int _EAGER_CONST_THRESHOLD = 128; | |||
public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||
{ | |||
if(tensor is EagerTensor) | |||
if(tensor is EagerTensor or NDArray) | |||
{ | |||
if (name == null) | |||
name = ops.uid().ToString(); | |||
@@ -10,4 +10,5 @@ public interface IOptimizer | |||
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
string name = null, | |||
bool experimental_aggregate_gradients = true); | |||
IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null); | |||
} |
@@ -216,10 +216,12 @@ namespace Tensorflow | |||
public virtual object get_attr(string name) | |||
{ | |||
var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||
tf.Status.Check(true); | |||
Status status = new(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
status.Check(true); | |||
var tf_buffer = c_api.TF_GetBuffer(buf); | |||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>()); | |||
var oneof_value = x.ValueCase; | |||
if (oneof_value == AttrValue.ValueOneofCase.None) | |||
@@ -64,36 +64,68 @@ namespace Tensorflow | |||
var num_elements = shape.size; | |||
var tensor_dtype = tensor.Dtype.as_tf_dtype(); | |||
// Expands the values read from a typed TensorProto field to exactly
// `num_elements` entries.
//
// A TensorProto may store fewer values than its shape requires; the Python
// reference implementation (tensor_util.MakeNdarray) then pads only at the
// END, repeating the last stored value. The result is therefore the stored
// values followed by copies of the last one.
//
// Returns an empty array when `src` is empty (the caller checks for a
// zero-sized result and substitutes zeros).
// Throws ArgumentException when `src` holds more values than the shape allows.
T[] ExpandArrayToSize<T>(IList<T> src)
{
    if (src.Count == 0)
    {
        return new T[0];
    }
    if (src.Count > num_elements)
    {
        throw new ArgumentException(
            $"The tensor proto holds {src.Count} values but its shape only allows {num_elements} elements.");
    }

    var last_elem = src[src.Count - 1];
    T[] res = new T[num_elements];
    for (long i = 0; i < num_elements; i++)
    {
        // Copy the stored prefix verbatim, then repeat the last stored value.
        res[i] = i < src.Count ? src[(int)i] : last_elem;
    }
    return res;
}
if (shape.ndim > 0 && tensor.TensorContent.Length > 0) | |||
{ | |||
return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype); | |||
} | |||
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) | |||
NDArray values; | |||
if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) | |||
{ | |||
return np.array(tensor.HalfVal.ToArray()).reshape(shape); | |||
values = np.array(ExpandArrayToSize(tensor.HalfVal)); | |||
} | |||
else if (tensor.Dtype == DataType.DtFloat) | |||
{ | |||
return np.array(tensor.FloatVal.ToArray()).reshape(shape); | |||
values = np.array(ExpandArrayToSize(tensor.FloatVal)); | |||
} | |||
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) | |||
{ | |||
return np.array(tensor.IntVal.ToArray()).reshape(shape); | |||
values = np.array(ExpandArrayToSize(tensor.IntVal)); | |||
} | |||
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) | |||
{ | |||
return np.array(tensor.Int64Val.ToArray()).reshape(shape); | |||
values = np.array(ExpandArrayToSize(tensor.Int64Val)); | |||
} | |||
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) | |||
{ | |||
return np.array(tensor.Uint64Val.ToArray()).reshape(shape); | |||
values = np.array(ExpandArrayToSize(tensor.Uint64Val)); | |||
} | |||
else if (tensor.Dtype == DataType.DtBool) | |||
{ | |||
return np.array(tensor.BoolVal.ToArray()).reshape(shape); | |||
values = np.array(ExpandArrayToSize(tensor.BoolVal)); | |||
} | |||
else | |||
{ | |||
throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " + | |||
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); | |||
} | |||
if(values.size == 0) | |||
{ | |||
return np.zeros(shape, tensor_dtype); | |||
} | |||
throw new NotImplementedException("MakeNdarray"); | |||
return values.reshape(shape); | |||
} | |||
private static readonly TF_DataType[] quantized_types = new TF_DataType[] | |||
@@ -1,5 +1,6 @@ | |||
using Google.Protobuf.Collections; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Trackables; | |||
@@ -11,12 +12,23 @@ public class TrackableConstant : Trackable | |||
_constant = constant; | |||
} | |||
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
public static (Tensor, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||
{ | |||
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; | |||
var ndarray = tensor_util.MakeNdarray(tensor_proto); | |||
var imported_constant = constant_op.constant(ndarray); | |||
return (new TrackableConstant(imported_constant), null); | |||
Tensor imported_constant; | |||
if (tensor_proto.Dtype == DataType.DtString) | |||
{ | |||
imported_constant = tf_with(ops.device("CPU"), _ => | |||
{ | |||
return constant_op.constant(ndarray); | |||
}); | |||
} | |||
else | |||
{ | |||
imported_constant = constant_op.constant(ndarray); | |||
} | |||
return (imported_constant, null); | |||
} | |||
} |
@@ -46,4 +46,9 @@ public class RevivedTypes | |||
return (null, null); | |||
} | |||
} | |||
/// <summary>
/// Stores <paramref name="obj"/> as the creator for <paramref name="identifier"/>,
/// replacing any creator previously stored under the same identifier.
/// </summary>
/// <param name="identifier">Key under which the creator is stored.</param>
/// <param name="obj">Creator to associate with the identifier.</param>
public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj)
    => _registered_revived_creator[identifier] = obj;
} |
@@ -137,7 +137,7 @@ public class SaveableView | |||
/// </summary> | |||
public List<int> dependency_sorted_node_ids() | |||
{ | |||
Dictionary<int, IEnumerable<int>> dependency_map = new(); | |||
Dictionary<int, List<int>> dependency_map = new(); | |||
foreach (var node in _nodes) | |||
{ | |||
var node_id = _node_ids[node]; | |||
@@ -116,17 +116,23 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
} | |||
Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||
foreach (var fdef in _sort_function_defs(library, function_deps)) | |||
// Debug(Rinne) | |||
var temp = _sort_function_defs(library, function_deps); | |||
int i = 0; | |||
foreach (var fdef in temp) | |||
{ | |||
i++; | |||
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
object structured_input_signature = null; | |||
object structured_outputs = null; | |||
if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
{ | |||
var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); | |||
structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); | |||
// TODO(Rinne): deal with structured_input_signature and structured_outputs. | |||
//var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
//structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); | |||
//structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); | |||
} | |||
graph.as_default(); | |||
@@ -234,27 +240,41 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) | |||
{ | |||
foreach(var op in func_graph.get_operations()) | |||
if(loaded_gradients is null || loaded_gradients.Count == 0) | |||
{ | |||
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
{ | |||
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
op.op._gradient_function = function._get_gradient_function(); | |||
} | |||
string gradient_op_type = null; | |||
try | |||
{ | |||
gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
} | |||
catch(InvalidArgumentError) | |||
foreach (var op in func_graph.get_operations()) | |||
{ | |||
continue; | |||
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
{ | |||
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
op.op._gradient_function = function._get_gradient_function(); | |||
} | |||
} | |||
if (loaded_gradients.ContainsKey(gradient_op_type)) | |||
} | |||
else | |||
{ | |||
foreach (var op in func_graph.get_operations()) | |||
{ | |||
var grad_fn = loaded_gradients[gradient_op_type]; | |||
grad_fn.NumPositionArgs = op.op.inputs.Length; | |||
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
{ | |||
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
op.op._gradient_function = function._get_gradient_function(); | |||
} | |||
string gradient_op_type = null; | |||
try | |||
{ | |||
gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
} | |||
catch (InvalidArgumentError) | |||
{ | |||
continue; | |||
} | |||
if (loaded_gradients.ContainsKey(gradient_op_type)) | |||
{ | |||
var grad_fn = loaded_gradients[gradient_op_type]; | |||
grad_fn.NumPositionArgs = op.op.inputs.Length; | |||
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||
} | |||
} | |||
} | |||
} | |||
@@ -15,6 +15,7 @@ using Tensorflow.Functions; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using Tensorflow.Trackables; | |||
using OneOf; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow | |||
{ | |||
@@ -34,7 +35,7 @@ namespace Tensorflow | |||
private List<int>? _filtered_nodes; | |||
private List<int> _ordered_node_ids; | |||
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
private List<Trackable> _nodes; | |||
private List<object> _nodes; | |||
private Dictionary<int, Action<object, object, object>> _node_setters; | |||
private Dictionary<string, ConcreteFunction> _concrete_functions; | |||
private HashSet<string> _restored_concrete_functions; | |||
@@ -213,7 +214,13 @@ namespace Tensorflow | |||
continue; | |||
} | |||
var proto = _proto.Nodes[node_id]; | |||
foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) | |||
if(node_id == 10522) | |||
{ | |||
// Debug(Rinne) | |||
Console.WriteLine(); | |||
} | |||
var temp = _get_node_dependencies(proto); | |||
foreach (var dep in _get_node_dependencies(proto).Values.Distinct()) | |||
{ | |||
deps.Add(dep); | |||
if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) | |||
@@ -232,7 +239,7 @@ namespace Tensorflow | |||
// The optimizer and original variable must be created before the slot | |||
// variable, since the slot variable is generated using the Optimizer's | |||
// add_slot API. | |||
var slot_deps = dependency_map[slot_variable_node_id]; | |||
var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>()); | |||
slot_deps.Add(node_id); | |||
slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); | |||
@@ -245,7 +252,12 @@ namespace Tensorflow | |||
} | |||
try | |||
{ | |||
return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>)); | |||
int total = 0; | |||
foreach(var v in dependency_map.Values) | |||
{ | |||
total += v.Count; | |||
} | |||
return TrackableUtils.order_by_dependency(dependency_map); | |||
} | |||
catch (TrackableUtils.CyclicDependencyError ex) | |||
{ | |||
@@ -339,9 +351,20 @@ namespace Tensorflow | |||
var saveable_object_proto = item.Value; | |||
var save_fn_id = saveable_object_proto.SaveFunction; | |||
var restore_fn_id = saveable_object_proto.RestoreFunction; | |||
saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); | |||
saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id)); | |||
} | |||
var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); | |||
if (saveable_objects is not null && saveable_objects.Count > 0) | |||
{ | |||
if(node is Trackable trackable) | |||
{ | |||
trackable.SelfSaveableObjectFactories = saveable_objects; | |||
} | |||
else | |||
{ | |||
throw new TypeError(); | |||
} | |||
} | |||
node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); | |||
} | |||
} | |||
} | |||
@@ -379,12 +402,12 @@ namespace Tensorflow | |||
{ | |||
// Use the public Optimizer interface when creating slot variables. | |||
var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; | |||
var optimizer_object = nodes[optimizer_node_id]; | |||
var optimizer_object = nodes[optimizer_node_id] as IOptimizer; | |||
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; | |||
// TODO(Rinne): implement it. | |||
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + | |||
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName); | |||
nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable; | |||
node_setters[slot_variable_proto.SlotVariableNodeId] = setattr; | |||
} | |||
else | |||
{ | |||
@@ -398,7 +421,7 @@ namespace Tensorflow | |||
{ | |||
nodes[0] = _recreate_base_user_object().Item1; | |||
} | |||
_nodes = new List<Trackable>(); | |||
_nodes = new List<object>(); | |||
for(int i = 0; i < _proto.Nodes.Count; i++) | |||
{ | |||
_nodes.Add(nodes[i]); | |||
@@ -412,7 +435,7 @@ namespace Tensorflow | |||
private void _restore_checkpoint() | |||
{ | |||
var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0))); | |||
tf_with(ops.device("CPU"), _ => | |||
{ | |||
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
@@ -467,7 +490,7 @@ namespace Tensorflow | |||
} | |||
} | |||
private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes) | |||
private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, object> nodes) | |||
{ | |||
if (_restored_concrete_functions.Contains(concrete_function_name)) | |||
{ | |||
@@ -485,12 +508,12 @@ namespace Tensorflow | |||
// TODO: implement it with concrete functions. | |||
} | |||
public Trackable get(int node_id) | |||
public object get(int node_id) | |||
{ | |||
return _nodes[node_id]; | |||
} | |||
public Trackable get(string node_id) | |||
public object get(string node_id) | |||
{ | |||
return get(_node_path_to_id[node_id]); | |||
} | |||
@@ -512,9 +535,9 @@ namespace Tensorflow | |||
} | |||
} | |||
private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes() | |||
private (Dictionary<int, object>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes() | |||
{ | |||
Dictionary<int, Trackable> nodes = new(); | |||
Dictionary<int, object> nodes = new(); | |||
Dictionary<int, Action<object, object, object>> node_setters = new(); | |||
foreach(var item in _loaded_nodes) | |||
{ | |||
@@ -534,10 +557,10 @@ namespace Tensorflow | |||
} | |||
} | |||
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | |||
private (object, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, object> nodes) | |||
{ | |||
// skip the registered classes. | |||
Dictionary<OneOf<string, int>, Trackable> dependencies = new(); | |||
Dictionary<OneOf<string, int>, object> dependencies = new(); | |||
foreach(var item in _get_node_dependencies(proto)) | |||
{ | |||
dependencies[item.Key] = nodes[item.Value]; | |||
@@ -558,7 +581,7 @@ namespace Tensorflow | |||
/// <param name="proto"></param> | |||
/// <param name="node_id"></param> | |||
/// <param name="dependencies"></param> | |||
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, object> dependencies) | |||
{ | |||
return proto.KindCase switch | |||
{ | |||
@@ -626,7 +649,7 @@ namespace Tensorflow | |||
} | |||
private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||
IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
IDictionary<OneOf<string, int>, object> dependencies) | |||
{ | |||
var fn = function_deserialization.recreate_function(proto, _concrete_functions); | |||
foreach (var name in proto.ConcreteFunctions) | |||
@@ -637,7 +660,7 @@ namespace Tensorflow | |||
} | |||
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
IDictionary<OneOf<string, int>, object> dependencies) | |||
{ | |||
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||
_setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||
@@ -78,7 +78,7 @@ namespace Tensorflow | |||
tf_with(ops.init_scope(), x => | |||
{ | |||
loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); | |||
root = loader.get(0); | |||
root = (Trackable)loader.get(0); | |||
// skip the assignment of `graph_debug_info`. | |||
}); | |||
// skip the assignment of `tensorflow_version` | |||
@@ -99,7 +99,7 @@ namespace Tensorflow | |||
} | |||
if(filters != null && filters.Count > 0) | |||
{ | |||
return filters.Keys.ToDictionary(x => x, x => loader.get(x)); | |||
return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x)); | |||
} | |||
else | |||
{ | |||
@@ -52,7 +52,7 @@ public static class TrackableUtils | |||
/// </summary> | |||
/// <param name="dependency_map"></param> | |||
/// <exception cref="ValueError"></exception> | |||
public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map) | |||
public static List<int> order_by_dependency(IDictionary<int, List<int>> dependency_map) | |||
{ | |||
Dictionary<int, HashSet<int>> reverse_dependency_map = new(); | |||
foreach (var pair in dependency_map) | |||
@@ -102,7 +102,7 @@ public static class TrackableUtils | |||
edges.Remove(x); | |||
if (edges.Count == 0) | |||
{ | |||
to_visit.Enqueue(dep); | |||
to_visit.Enqueue(dep); | |||
if (!reverse_dependency_map.Remove(dep)) | |||
{ | |||
throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); | |||
@@ -333,5 +333,23 @@ namespace Tensorflow | |||
}); | |||
return array_ops.identity(value); | |||
} | |||
//public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y; | |||
//public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y; | |||
//public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y; | |||
//public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value(); | |||
//public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y; | |||
//public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y; | |||
//public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y; | |||
//public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y; | |||
//public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value(); | |||
//public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value(); | |||
//public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y; | |||
//public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y; | |||
//public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y; | |||
//public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y; | |||
} | |||
} |
@@ -1,19 +1,6 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow | |||
@@ -169,6 +169,12 @@ namespace Tensorflow.Keras | |||
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); | |||
} | |||
/// <summary>
/// Writes <paramref name="value"/> into the variable <paramref name="x"/>
/// by calling its <c>assign</c> method.
/// </summary>
/// <param name="x">Variable to update.</param>
/// <param name="value">New value, forwarded to <c>assign</c> as-is.</param>
// TODO(Rinne): check the implementation.
public void set_value(IVariableV1 x, object value)
    => x.assign(value);
public void batch_set_value(List<(IVariableV1, NDArray)> tuples) | |||
{ | |||
if (ops.executing_eagerly_outside_functions()) | |||
@@ -36,6 +36,11 @@ namespace Tensorflow.Keras | |||
} | |||
} | |||
/// <summary>
/// Type initializer: registers a <see cref="RestoredOptimizer"/> as the
/// revived-type creator for the "optimizer" identifier.
/// </summary>
static KerasInterface()
{
    var optimizer_creator = new RestoredOptimizer();
    RevivedTypes.RegisterRevivedTypeCreator("optimizer", optimizer_creator);
}
public KerasDataset datasets { get; } = new KerasDataset(); | |||
public IInitializersApi initializers { get; } = new InitializersApi(); | |||
public Regularizers regularizers { get; } = new Regularizers(); | |||
@@ -14,11 +14,11 @@ namespace Tensorflow.Keras.Optimizers | |||
protected bool _hypers_created; | |||
protected virtual string _name { get; } | |||
IVariableV1 _iterations; | |||
protected IVariableV1 _iterations; | |||
protected ResourceVariable iterations => _iterations as ResourceVariable; | |||
List<IVariableV1> _weights; | |||
Dictionary<string, float> _hyper; | |||
Dictionary<string, IVariableV1> _hyper_variables; | |||
protected Dictionary<string, float> _hyper; | |||
protected Dictionary<string, IVariableV1> _hyper_variables; | |||
protected bool _momentum; | |||
protected float _initial_decay = 0.0f; | |||
protected bool _use_locking = true; | |||
@@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Optimizers | |||
} | |||
} | |||
protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) | |||
public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) | |||
{ | |||
if (initializer == null) | |||
initializer = tf.zeros_initializer; | |||
@@ -0,0 +1,63 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
namespace Tensorflow.Keras.Optimizers | |||
{ | |||
/// <summary>
/// Optimizer stand-in that is created from a <c>SavedUserObject</c> proto and
/// filled in through <see cref="SetValue"/>. It carries the "optimizer"
/// identifier and cannot produce a config.
/// </summary>
public class RestoredOptimizer: OptimizerV2, ITrackableWrapper, IKerasConfig
{
    /// <summary>Identifier of this revived type.</summary>
    public String Identifier { get; } = "optimizer";
    public int Version { get; } = 2;
    public int MinConsumerVersion { get; } = 1;
    public int MinProducerVersion { get; } = 1;

    /// <summary>
    /// Creates the optimizer under the name "RestoredOptimizer" and marks its
    /// hyper-parameters as already created.
    /// </summary>
    public RestoredOptimizer(): base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" })
    {
        _hypers_created = true;
    }

    /// <summary>Not supported for restored optimizers.</summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public IKerasConfig get_config()
    {
        throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " +
            "supported. Please file a feature request if this limitation bothers you.");
    }

    /// <summary>
    /// Attaches a restored attribute to this optimizer.
    /// Trackable values are tracked under <paramref name="name"/>; variables are
    /// stored as hyper-parameter variables (or assigned into the variable that is
    /// already stored under that name); floats are stored as plain hyper-parameters.
    /// </summary>
    /// <param name="name">Attribute name; must be a string.</param>
    /// <param name="value">Attribute value.</param>
    /// <exception cref="TypeError"><paramref name="name"/> is not a string.</exception>
    /// <exception cref="NotImplementedException">
    /// <paramref name="value"/> is neither a variable nor a float.
    /// </exception>
    public void SetValue(object name, object value)
    {
        if (name is not String str)
        {
            // Guard the null case so the intended TypeError is raised instead of
            // a NullReferenceException from name.GetType().
            var actual_type = name is null ? "null" : name.GetType().ToString();
            throw new TypeError($"The name of value to set must be string, but got {actual_type}");
        }
        if (value is Trackable trackable)
        {
            _track_trackable(trackable, str, overwrite: true);
        }
        if (value is IVariableV1 resource_variable)
        {
            if (_hyper_variables.TryGetValue(str, out var existing_variable))
            {
                // Update the variable that is already registered under this name;
                // assigning the incoming variable to itself would be a no-op.
                keras.backend.set_value(existing_variable, value);
            }
            else
            {
                _hyper_variables[str] = resource_variable;
            }
        }
        else if (value is float f)
        {
            _hyper[str] = f;
        }
        else
        {
            throw new NotImplementedException();
        }
    }

    /// <summary>Creates a fresh <see cref="RestoredOptimizer"/>; the proto content is not read.</summary>
    public Trackable FromProto(SavedUserObject proto)
    {
        return new RestoredOptimizer();
    }
}
} |
@@ -2,6 +2,7 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.UnitTest.Helpers; | |||
using Tensorflow.NumPy; | |||
@@ -103,4 +104,13 @@ public class SequentialModelLoad | |||
classify_model.fit(x, y, batch_size: 4); | |||
} | |||
// Loads a SavedModel from a developer-local path and prints its summary.
// Ignored by default because the path only exists on the author's machine.
[Ignore]
[TestMethod]
public void TestModelBeforeTF2_5()
{
    // NOTE(review): this access looks intentional — presumably it forces the Keras
    // API to initialize before loading; confirm before removing.
    var a = keras.layers;
    var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model;
    // Fail with a clear message instead of a NullReferenceException when the
    // loaded root object is not a Model.
    Assert.IsNotNull(model, "The loaded SavedModel root is not a Model.");
    model.summary();
}
} |