Browse Source

Support construct graph from proto.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
355ca3ab6c
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
30 changed files with 1216 additions and 25 deletions
  1. +18
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.io.cs
  3. +6
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  4. +288
    -0
      src/TensorFlowNET.Core/Framework/function_def_lib.cs
  5. +102
    -8
      src/TensorFlowNET.Core/Framework/importer.cs
  6. +12
    -0
      src/TensorFlowNET.Core/Framework/versions.cs
  7. +10
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  8. +1
    -0
      src/TensorFlowNET.Core/Functions/Function.cs
  9. +12
    -0
      src/TensorFlowNET.Core/Functions/IGenericFunction.cs
  10. +88
    -0
      src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs
  11. +14
    -0
      src/TensorFlowNET.Core/Gradients/custom_gradient.cs
  12. +3
    -1
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  13. +5
    -0
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  14. +6
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  15. +2
    -0
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  16. +4
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  17. +28
    -0
      src/TensorFlowNET.Core/Operations/handle_data_util.cs
  18. +1
    -1
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  19. +12
    -3
      src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs
  20. +344
    -1
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  21. +19
    -5
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  22. +14
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs
  23. +5
    -0
      src/TensorFlowNET.Core/ops.cs
  24. +13
    -0
      src/TensorFlowNET.Keras/Saving/KerasMetaData.cs
  25. +18
    -4
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  26. +62
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs
  27. +37
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs
  28. +73
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
  29. +9
    -0
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  30. +9
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 18
- 0
src/TensorFlowNET.Core/APIs/tf.compat.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System.Text;

namespace Tensorflow
@@ -45,6 +46,23 @@ namespace Tensorflow
{
return as_text(bytes_or_text, encoding);
}

public ByteString as_bytes(ByteString bytes, Encoding encoding = null)
{
return bytes;
}
public ByteString as_bytes(byte[] bytes, Encoding encoding = null)
{
return ByteString.CopyFrom(bytes);
}
public ByteString as_bytes(string text, Encoding encoding = null)
{
if(encoding is null)
{
encoding = Encoding.UTF8;
}
return ByteString.CopyFrom(encoding.GetBytes(text));
}
}

public bool executing_eagerly()


+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -54,6 +54,6 @@ namespace Tensorflow
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
string name = null,
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list);
}
}

+ 6
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -156,6 +156,12 @@ namespace Tensorflow.Contexts
return has_graph_arg;
}

public bool has_function(string name)
{
ensure_initialized();
return c_api.TFE_ContextHasFunction(_handle, name);
}

public void restore_mode()
{
context_switches.Pop();


+ 288
- 0
src/TensorFlowNET.Core/Framework/function_def_lib.cs View File

@@ -0,0 +1,288 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow.Framework
{
public class function_def_lib
{
// TODO(Rinne): process signatures and structured outputs.
public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature,
object? structured_outputs, List<TensorShapeProto> input_shapes = null)
{
var func_graph = new FuncGraph(fdef.Signature.Name);
if(input_shapes is null)
{
if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr))
{
var raw_input_shapes = input_shapes_attr.List.Shape;
input_shapes = new List<TensorShapeProto>();
foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y)))
{
if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0)
{
input_shapes.Add(null);
}
else
{
input_shapes.Add(input_shape);
}
}
}
}

var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes);

func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));

var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);

foreach(var node in graph_def.Node)
{
if(node.Attr.TryGetValue("_output_shapes", out var output_shapes))
{
var op = func_graph.get_operation_by_name(node.Name);
foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length)))
{
op.outputs[output_index].shape = new Shape(shape);
}
}
}
Dictionary<long, string> output_names = new();
foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names))
{
output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name;
}
// TODO(Rinne): func_graph._output_names = output_names

func_graph.Exit();
return func_graph;
}

public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes)
{
var graph_def = new GraphDef()
{
Versions = new VersionDef()
{
Producer = versions.GRAPH_DEF_VERSION,
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
}
};

var default_graph = ops.get_default_graph();

if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count)
{
throw new ValueError($"Length of `input_shapes` must match the number " +
$"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " +
$"{fdef.Signature.InputArg.Count} `input_arg`s.");
}

foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg))
{
NodeDef node_def = new();
node_def.Name = arg_def.Name;
node_def.Op = "Placeholder";
node_def.Attr["dtype"] = new AttrValue()
{
Type = arg_def.Type
};
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null)
{
var input_shape = input_shapes[i];
// skip the condition that input_shape is not `TensorShapeProto`.
AttrValue shape = new AttrValue()
{
Shape = new TensorShapeProto()
};
shape.Shape = new TensorShapeProto(input_shape);
node_def.Attr["shape"] = shape;
}
if (!fdef.ArgAttr.ContainsKey((uint)i))
{
fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs();
}
var arg_attrs = fdef.ArgAttr[(uint)i].Attr;
foreach(var k in arg_attrs.Keys)
{
if(k == "_output_shapes")
{
if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List)
{
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]);
}
else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape)
{
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape);
}
}
else if (k.StartsWith("_"))
{
if (!node_def.Attr.ContainsKey(k))
{
node_def.Attr[k] = new AttrValue();
}
node_def.Attr[k] = new AttrValue(arg_attrs[k]);
}
}

graph_def.Node.Add(node_def);
}

graph_def.Node.AddRange(fdef.NodeDef);

Dictionary<string, string> nested_to_flat_tensor_name = new();
foreach(var arg_def in fdef.Signature.InputArg)
{
nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0";
string control_name = "^" + arg_def.Name;
nested_to_flat_tensor_name[control_name] = control_name;
}

foreach(var node_def in fdef.NodeDef)
{
var graph = default_graph;
// TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future.
while(graph.OuterGraph is not null)
{
graph = graph.OuterGraph;
}

var op_def = default_graph.GetOpDef(node_def.Op);

foreach(var attr in op_def.Attr)
{
if(attr.Type == "func")
{
var fname = node_def.Attr[attr.Name].Func.Name;
if (!is_function(fname))
{
throw new ValueError($"Function {fname} was not found. Please make sure " +
$"the FunctionDef `fdef` is correct.");
}
}
else if(attr.Type == "list(func)")
{
foreach(var fn in node_def.Attr[attr.Name].List.Func)
{
var fname = fn.Name;
if (!is_function(fname))
{
throw new ValueError($"Function {fname} was not found. Please make " +
$"sure the FunctionDef `fdef` is correct.");
}
}
}
}

int flattened_index = 0;
foreach(var arg_def in op_def.OutputArg)
{
var num_args = _get_num_args(arg_def, node_def);
for(int i = 0; i < num_args; i++)
{
var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}";
var flat_name = $"{node_def.Name}:{flattened_index}";
nested_to_flat_tensor_name[nested_name] = flat_name;
flattened_index++;
}
}
string control_name = "^" + node_def.Name;
nested_to_flat_tensor_name[control_name] = control_name;
}

foreach(var node_def in graph_def.Node)
{
for(int i = 0; i < node_def.Input.Count; i++)
{
node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]];
}
}

return (graph_def, nested_to_flat_tensor_name);
}

private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef)
{
foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg)))
{
if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0)
{
tensor.shape = Shape.Scalar;

var shape_and_type = arg_def.HandleData[0];
var handle_data = new HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new HandleShapeAndType()
{
Shape = shape_and_type.Shape,
Dtype = shape_and_type.Dtype
});
resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true);
}
}
}

private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def)
{
if (!string.IsNullOrEmpty(arg_def.NumberAttr))
{
return node_def.Attr[arg_def.NumberAttr].I;
}
else if(!string.IsNullOrEmpty(arg_def.TypeListAttr))
{
return node_def.Attr[arg_def.TypeListAttr].List.Type.Count;
}
else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid)
{
return 1;
}
else
{
throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " +
$"FunctionDef `fdef` is correct.");
}
}

public static bool is_function(string fname)
{
if (tf.Context.executing_eagerly())
{
return tf.Context.has_function(fname);
}
else
{
var graph = ops.get_default_graph();
while(graph is not null)
{
if (graph.IsFunction(fname))
{
return true;
}
if(graph.OuterGraph is not null)
{
graph = graph.OuterGraph;
}
else
{
return false;
}
}
}
throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}

+ 102
- 8
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -17,6 +17,7 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;
@@ -25,9 +26,14 @@ namespace Tensorflow
{
public class importer
{
public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null)
{
return import_graph_def(graph_def, validate_colocation_constraints: false, name: name);
}
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
bool validate_colocation_constraints = true,
string name = null,
OpList producer_op_list = null)
{
@@ -60,7 +66,7 @@ namespace Tensorflow
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
var status = new Status();
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints );
// need to create a class ImportGraphDefWithResults with IDisposal
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status));
status.Check(true);
@@ -73,6 +79,42 @@ namespace Tensorflow
return _GatherReturnElements(return_elements, graph, results);
}

//private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null,
// bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null)
//{
// graph_def = _ProcessGraphDefParam(graph_def);
// input_map = _ProcessInputMapParam(input_map);
// return_elements = _ProcessReturnElementsParam(return_elements);

// if(producer_op_list is not null)
// {
// _RemoveDefaultAttrs(producer_op_list, graph_def);
// }

// var graph = ops.get_default_graph();
// string prefix = null;
// tf_with(ops.name_scope(name, "import", input_map.Values), scope =>
// {
// if (scope is not null)
// {
// Debug.Assert(scope.scope_name.EndsWith("/"));
// prefix = scope.scope_name[scope.scope_name.Length - 1].ToString();
// }
// else
// {
// prefix = "";
// }

// input_map = _ConvertInputMapValues(name, input_map);
// });

// var scope_options = c_api_util.ScopedTFImportGraphDefOptions();
// var options = scope_options.Options;
// _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints);
//}

private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
Graph graph,
TF_ImportGraphDefResults results)
@@ -113,15 +155,29 @@ namespace Tensorflow
public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
string prefix,
Dictionary<string, Tensor> input_map,
string[] return_elements)
string[] return_elements,
bool validate_colocation_constraints)
{
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true);

foreach (var input in input_map)
{
var (src_name, src_index) = _ParseTensorName(input.Key);
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output());
var input_src = tf.compat.as_str(input.Key);
var input_dst = input.Value;
if (input_src.StartsWith("^"))
{
var src_name = tf.compat.as_str(input_src.Substring(1));
var dst_op = input_dst._as_tf_output().oper;
c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op);
}
else
{
var (src_name, src_index) = _ParseTensorName(input.Key);
src_name = tf.compat.as_str(src_name);
var dst_output = input_dst._as_tf_output();
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output);
}
}

if (return_elements == null)
@@ -132,15 +188,16 @@ namespace Tensorflow
if (name.Contains(":"))
{
var (op_name, index) = _ParseTensorName(name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
op_name = tf.compat.as_str(op_name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index);
}
else
{
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name));
}
}

// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints);
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints);
}

private static (string, int) _ParseTensorName(string tensor_name)
@@ -173,6 +230,14 @@ namespace Tensorflow
return graph_def;
}

private static GraphDef _ProcessGraphDefParam(GraphDef graph_def)
{
var old_graph_def = graph_def;
graph_def = new GraphDef(old_graph_def);

return graph_def;
}

private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
{
foreach (var attr_def in op_def.Attr)
@@ -240,6 +305,35 @@ namespace Tensorflow
}
}

private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def)
{
var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x);

foreach (var node in graph_def.Node)
{
// Remove any default attr values that aren't in op_def.
if (producer_op_dict.ContainsKey(node.Op))
{
var op_def = op_def_registry.GetOpDef(node.Op);
if(op_def is null)
{
continue;
}
var producer_op_def = producer_op_dict[node.Op];
foreach (var key in node.Attr.Keys)
{
if (_FindAttrInOpDef(key, op_def) is null)
{
var attr_def = _FindAttrInOpDef(key, producer_op_def);
if (attr_def != null && attr_def.DefaultValue != null &&
node.Attr[key] == attr_def.DefaultValue)
node.Attr[key].ClearValue();
}
}
}
}
}

private static AttrDef _FindAttrInOpDef(string name, OpDef op_def)
{
return op_def.Attr.FirstOrDefault(x => x.Name == name);


+ 12
- 0
src/TensorFlowNET.Core/Framework/versions.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Framework
{
public class versions
{
public static int GRAPH_DEF_VERSION = 1286;
public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0;
}
}

+ 10
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -13,6 +13,7 @@ namespace Tensorflow.Functions
/// </summary>
public class ConcreteFunction: Trackable
{
protected IEnumerable<Tensor> _captured_inputs;
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
@@ -29,11 +30,13 @@ namespace Tensorflow.Functions
public ConcreteFunction(string name)
{
func_graph = new FuncGraph(name);
_captured_inputs = func_graph.external_captures;
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
{
func_graph = graph;
_captured_inputs = func_graph.external_captures;

ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
}
@@ -53,6 +56,7 @@ namespace Tensorflow.Functions
new[] { output },
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
}

public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
@@ -73,6 +77,7 @@ namespace Tensorflow.Functions
new[] { output.variant_tensor },
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
}

/*public ConcreteFunction(Func<Tensors, Tensors> func,
@@ -174,6 +179,11 @@ namespace Tensorflow.Functions
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
}

public void SetExternalCaptures(IEnumerable<Tensor> captures)
{
_captured_inputs = captures;
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);


+ 1
- 0
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -1,4 +1,5 @@
using System;
using Tensorflow.Functions;
using Tensorflow.Train;

namespace Tensorflow


+ 12
- 0
src/TensorFlowNET.Core/Functions/IGenericFunction.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Functions
{
public interface IGenericFunction
{
object[] Apply(params object[] args);
ConcreteFunction get_concrete_function(params object[] args);
}
}

+ 88
- 0
src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs View File

@@ -0,0 +1,88 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
{
public static class function_saved_model_utils
{
/// <summary>
///
/// </summary>
/// <param name="concrete_function"></param>
/// <param name="inputs">a list tensors or other objects (such as variables) which
/// contain tensors that were originally captured by the function</param>
public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs)
{
var bound_inputs = inputs?.Select(obj =>
{
if(obj is Tensor tensor)
{
return get_tensor_from_node(tensor);
}
else if(obj is IVariableV1 variable)
{
return get_tensor_from_node(variable);
}
else
{
throw new TypeError("Encountered an type error, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
});
var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1);

List<Tensor> captured_inputs_list = new();
// TODO(Rinne): concrete_function.set_variables(bound_variables)


if (bound_inputs is not null)
{
foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count())))
{
if(hasattr(bound_input, "__tf_experimental_restore_capture__"))
{
throw new NotImplementedException();
}
else
{
captured_inputs_list.Add(bound_input);
concrete_function.func_graph.replace_capture(bound_input, internal_capture);
if(internal_capture.dtype == dtypes.resource)
{
// skip the check of variable.
handle_data_util.copy_handle_data(bound_input, internal_capture);
}
concrete_function.func_graph.capture(bound_input);
}
}
}

if(captured_inputs_list.Any(inp => inp is null))
{
// TODO(Rinne): add warnings.
}
concrete_function.SetExternalCaptures(captured_inputs_list);
}

public static Tensor get_tensor_from_node(Tensor node)
{
return node;
}
public static Tensor get_tensor_from_node(IVariableV1 node)
{
if (resource_variable_ops.is_resource_variable(node))
{
return node.Handle;
}
else
{
throw new TypeError("Encountered an type error, please submit an issue to " +
"https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Gradients/custom_gradient.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
public class custom_gradient
{
public static string generate_name()
{
return $"CustomGradient-{ops.uid()}";
}
}
}

+ 3
- 1
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -1,6 +1,7 @@
using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Functions;
@@ -21,8 +22,9 @@ namespace Tensorflow.Graphs

public override void OnEntry(MethodExecutionArgs args)
{
File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc");
// TODO: func_name can be cache in FullName + Args
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}";

if (functions.ContainsKey(func_name))
{


+ 5
- 0
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -56,6 +56,11 @@ public class FuncGraph : Graph, IDisposable
_handle = handle;
}

public void replace_capture(Tensor tensor, Tensor placeholder)
{
_captures[tensor.Id] = (tensor, placeholder);
}

public void ToGraph(Operation[] opers,
Tensor[] inputs, Tensor[] outputs,
string[] output_names)


+ 6
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -146,6 +146,12 @@ namespace Tensorflow
return ops.set_default_graph(this);
}

public bool IsFunction(string name)
{
// TODO(Rinne): deal with `_functions`.
throw new NotImplementedException();
}

private Tensor _as_graph_element(object obj)
{
if (obj is RefVariable var)


+ 2
- 0
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions
_handle = c_api.TF_NewImportGraphDefOptions();
}

public SafeImportGraphDefOptionsHandle Options => _handle;

public void AddReturnOutput(string name, int index)
{
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);


+ 4
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -185,6 +185,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name);

[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints);

/// <summary>
/// Add an output in `graph_def` to be returned via the `return_outputs` output
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input
@@ -246,7 +249,7 @@ namespace Tensorflow
/// <param name="ops">TF_ImportGraphDefOptions*</param>
/// <param name="uniquify_prefix">unsigned char</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix);
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix);

/// <summary>
/// Fetches the return operations requested via


+ 28
- 0
src/TensorFlowNET.Core/Operations/handle_data_util.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow.Operations
{
public static class handle_data_util
{
public static void copy_handle_data(Tensor source_t, Tensor target_t)
{
if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant)
{
SafeTensorHandle handle_data;
if(source_t is EagerTensor)
{
handle_data = source_t.Handle;
}
else
{
handle_data = ops.get_resource_handle_data(source_t);
}
throw new NotImplementedException();
//if(handle_data is not null && handle_data.)
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -126,7 +126,7 @@ namespace Tensorflow
/// <param name="handle"></param>
/// <param name="handle_data"></param>
/// <param name="graph_mode"></param>
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode)
{
if (!graph_mode)
return;


+ 12
- 3
src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs View File

@@ -5,6 +5,7 @@
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using Tensorflow.Framework.Models;
using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
@@ -2589,9 +2590,17 @@ namespace Tensorflow {
}
}

#region Nested types
/// <summary>Container for nested types declared in the FunctionSpec message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
//public static FunctionSpec from_function_and_signature(string csharp_function, IEnumerable<TensorSpec> input_signature, bool is_pure = false, object jit_compile = null)
//{
// // TODO(Rinne): _validate_signature(input_signature)
// // TODO(Rinne): _validate_python_function(python_function, input_signature)


//}

#region Nested types
/// <summary>Container for nested types declared in the FunctionSpec message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static partial class Types {
/// <summary>
/// Whether the function should be compiled by XLA.


+ 344
- 1
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -1,14 +1,24 @@
using System;
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.RegularExpressions;
using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Training.Saving.SavedModel
{
public static class function_deserialization
{
private static string _INFERENCE_PREFIX = "__inference_";
private static string _FUNCTION_WRAPPER_NAME_REGEX = $@"^{_INFERENCE_PREFIX}(.*)_\d+$";
/// <summary>
/// Creates a `Function` from a `SavedFunction`.
/// </summary>
@@ -22,6 +32,338 @@ namespace Tensorflow.Training.Saving.SavedModel
return null;
}

public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library,
SavedObjectGraph saved_object_graph = null, string load_shared_name_suffix = null, object? wrapper_function = null)
{
var library_function_names = library.Function.Select(x => x.Signature.Name).Distinct();
Dictionary<string, ConcreteFunction> functions = new();
Dictionary<string, ConcreteFunction> renamed_functions = new();

Graph graph;
if (ops.executing_eagerly_outside_functions())
{
graph = new Graph();
}
else
{
graph = ops.get_default_graph();
}

if(load_shared_name_suffix is null)
{
load_shared_name_suffix = $"_load_{ops.uid()}";
}

Dictionary<ByteString, string> library_gradient_names = new();
Dictionary<ByteString, string> new_gradient_op_types = new();
Dictionary<string, string> gradients_to_register = new();
foreach (var gdef in library.RegisteredGradients)
{
if(gdef.RegisteredOpType is not null)
{
var new_op_type = custom_gradient.generate_name();
var old_op_type = tf.compat.as_bytes(gdef.RegisteredOpType);

library_gradient_names[old_op_type] = gdef.GradientFunc;
new_gradient_op_types[old_op_type] = new_op_type;
gradients_to_register[gdef.GradientFunc] = new_op_type;
}
}

Dictionary<string, IEnumerable<string>> function_deps = new();
foreach(var fdef in library.Function)
{
function_deps[fdef.Signature.Name] = _list_function_deps(fdef, library_function_names, library_gradient_names);
}

Dictionary<string, ConcreteFunction> loaded_gradients = new();
int aa = 0;
var temp = _sort_function_defs(library, function_deps);
foreach (var fdef in temp)
{
aa++;
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);

if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
{
// TODO(Rinne): implement it.
//var proto = saved_object_graph.ConcreteFunctions[orig_name];
//throw new NotImplementedException();
}

graph.as_default();
var func_graph = function_def_lib.function_def_to_graph(fdef, null, null);
graph.Exit();

_restore_gradient_functions(func_graph, renamed_functions, loaded_gradients);

foreach(var dep in function_deps[orig_name])
{
functions[dep].AddTograph(func_graph);
}

if (fdef.Attr.ContainsKey("_input_shapes"))
{
fdef.Attr.Remove("_input_shapes");
}
var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString()));
if(wrapper_function is not null)
{
throw new NotImplementedException();
}
func.AddTograph(graph);

functions[orig_name] = func;
renamed_functions[func.Name] = func;
if(func_graph.get_operations().Any(op => op.op.type == "TRTEngineOp"))
{
func.AddTograph(ops.get_default_graph());
}

if (gradients_to_register.ContainsKey(orig_name))
{
var gradient_op_type = gradients_to_register[orig_name];
loaded_gradients[gradient_op_type] = func;
// TODO(Rinne): deal with gradient registry.
//new RegisteredGradient() { RegisteredOpType = gradient_op_type }.
}
}
return functions;
}

public static void fix_node_def(NodeDef node_def, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix)
{
if (functions.ContainsKey(node_def.Op))
{
node_def.Op = functions[node_def.Op].Name;
}
foreach(var attr_value in node_def.Attr.Values)
{
if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func)
{
attr_value.Func.Name = functions[attr_value.Func.Name].Name;
}
else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List)
{
foreach(var fn in attr_value.List.Func)
{
fn.Name = functions[fn.Name].Name;
}
}
}

if(node_def.Op == "HashTableV2")
{
if(!node_def.Attr.ContainsKey("use_node_name_sharing") || !node_def.Attr["use_node_name_sharing"].B)
{
node_def.Attr["use_node_name_sharing"].B = true;
shared_name_suffix += $"_{ops.uid()}";
}
}

var op_def = op_def_registry.GetOpDef(node_def.Op);
if(op_def is not null)
{
var attr = op_def.Attr.Where(x => x.Name == "shared_name").FirstOrDefault();
if(attr is not null)
{
ByteString shared_name = null;
if(node_def.Attr.ContainsKey("shared_name") && node_def.Attr["shared_name"].S is not null)
{
shared_name = node_def.Attr["shared_name"].S;
}
else if(attr.DefaultValue.S is not null)
{
shared_name = tf.compat.as_bytes(attr.DefaultValue.S);
}
if(shared_name is null)
{
shared_name = tf.compat.as_bytes(node_def.Name);
}
node_def.Attr["shared_name"].S = ByteString.CopyFrom(shared_name.Concat(tf.compat.as_bytes(node_def.Name)).ToArray());
}
}
}

private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
{
foreach(var op in func_graph.get_operations())
{
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()];
// TODO(Rinne): deal with `op._gradient_function`.
}
string gradient_op_type = null;
try
{
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
}
catch(Exception e)
{
continue;
}
if (loaded_gradients.ContainsKey(gradient_op_type))
{
var grad_fn = loaded_gradients[gradient_op_type];
grad_fn.NumPositionArgs = op.op.inputs.Length;
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
}
}
}

private static string _fix_fdef_in_place(FunctionDef fdef, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix,
IDictionary<ByteString, string> new_gradient_op_types)
{
var orig_name = fdef.Signature.Name;
bool contains_unsaved_custom_gradients = false;

foreach(var node_def in fdef.NodeDef)
{
fix_node_def(node_def, functions, shared_name_suffix);
var op_type = _get_gradient_op_type(node_def);
if(op_type is not null)
{
if (new_gradient_op_types.ContainsKey(op_type))
{
node_def.Attr["_gradient_op_type"].S = tf.compat.as_bytes(new_gradient_op_types[op_type]);
}
else
{
contains_unsaved_custom_gradients = true;
}
}
}
if (contains_unsaved_custom_gradients)
{
// TODO(Rinne): log warnings.
}

fdef.Signature.Name = _clean_function_name(fdef.Signature.Name);
return orig_name;
}

private static string _clean_function_name(string name)
{
var match = Regex.Match(name, _FUNCTION_WRAPPER_NAME_REGEX);
if(match.Success)
{
return match.Groups[1].Value;
}
else
{
return name;
}
}

/// <summary>
/// Return a topologic sort of FunctionDefs in a library.
/// </summary>
/// <param name="library"></param>
/// <param name="function_deps"></param>
private static IEnumerable<FunctionDef> _sort_function_defs(FunctionDefLibrary library, Dictionary<string, IEnumerable<string>> function_deps)
{
Dictionary<string, IList<string>> edges = new();
Dictionary<string, int> in_count = new();
foreach(var item in function_deps)
{
var fname = item.Key;
var deps = item.Value;
if(deps is null || deps.Count() == 0)
{
in_count[fname] = 0;
continue;
}
foreach(var dep in deps)
{
edges.SetDefault(dep, new List<string>()).Add(fname);
if (in_count.ContainsKey(fname))
{
in_count[fname]++;
}
else
{
in_count[fname] = 1;
}
}
}
var ready = new Stack<string>(library.Function.
Where(x => in_count[x.Signature.Name] == 0)
.Select(x => x.Signature.Name).ToList());
List<string> output = new();
while(ready.Count > 0)
{
var node = ready.Pop();
output.Add(node);
if (!edges.ContainsKey(node))
{
continue;
}
foreach(var dest in edges[node])
{
in_count[dest] -= 1;
if (in_count[dest] == 0)
{
ready.Push(dest);
}
}
}

if(output.Count != library.Function.Count)
{
var failed_to_resolve = in_count.Keys.Except(output);
throw new ValueError($"There is a cyclic dependency between functions. " +
$"Could not resolve ({string.Join(", ", failed_to_resolve)}).");
}

var reverse = library.Function.ToDictionary(x => x.Signature.Name, x => x);
return output.Select(x => reverse[x]);
}

private static IEnumerable<string> _list_function_deps(FunctionDef fdef, IEnumerable<string> library_function_names, IDictionary<ByteString, string> library_gradient_names)
{
HashSet<string> deps = new HashSet<string>();
foreach(var node_def in fdef.NodeDef)
{
var grad_op_type = _get_gradient_op_type(node_def);
if (library_function_names.Contains(node_def.Op))
{
deps.Add(node_def.Op);
}
else if(grad_op_type is not null && library_gradient_names.TryGetValue(grad_op_type, out var gradient_name))
{
deps.Add(gradient_name);
}
else
{
foreach(var attr_value in node_def.Attr.Values)
{
if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func)
{
deps.Add(attr_value.Func.Name);
}
else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List)
{
foreach(var fn in attr_value.List.Func)
{
deps.Add(fn.Name);
}
}
}
}
}
return deps.AsEnumerable();
}

private static ByteString _get_gradient_op_type(NodeDef node_def)
{
if(node_def.Attr.ContainsKey("_gradient_op_type") && node_def.Op != "StatefulPartitionedCall" && node_def.Op != "PartitionedCall")
{
return node_def.Attr["_gradient_op_type"].S;
}
return null;
}

public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
IDictionary<string, ConcreteFunction> concrete_functions)
{
@@ -30,6 +372,7 @@ namespace Tensorflow.Training.Saving.SavedModel
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;

var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
// TODO(Rinne): set the functiona spec.
concrete_function.AddTograph();
return concrete_function;
}


+ 19
- 5
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -35,6 +35,8 @@ namespace Tensorflow
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes;
private List<Trackable> _nodes;
private Dictionary<int, Action<object, object, object>> _node_setters;
private Dictionary<string, ConcreteFunction> _concrete_functions;
private HashSet<string> _restored_concrete_functions;
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir,
CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters)
{
@@ -44,6 +46,9 @@ namespace Tensorflow
_proto = object_graph_proto;
_export_dir = export_dir;
// TODO: `this._concrete_functions` and `this._restored_concrete_functions`
_concrete_functions = function_deserialization.load_function_def_library(
meta_graph.GraphDef.Library, _proto);
_restored_concrete_functions = new HashSet<string>();
_checkpoint_options = ckpt_options;
_save_options = save_options;

@@ -464,9 +469,17 @@ namespace Tensorflow
}
}

private void _setup_function_captures()
private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes)
{
// TODO: implement it with concrete functions.
if (_restored_concrete_functions.Contains(concrete_function_name))
{
return;
}
_restored_concrete_functions.Add(concrete_function_name);
var concrete_function = _concrete_functions[concrete_function_name];
var proto = _proto.ConcreteFunctions[concrete_function_name];
var inputs = proto.BoundInputs.Select(x => nodes[x]);
function_saved_model_utils.restore_captures(concrete_function, inputs);
}

private void _setup_remaining_functions()
@@ -625,7 +638,7 @@ namespace Tensorflow
var fn = function_deserialization.recreate_function(proto, null);
foreach (var name in proto.ConcreteFunctions)
{
_setup_function_captures();
_setup_function_captures(name, dependencies);
}
return (fn, setattr);
}
@@ -633,8 +646,9 @@ namespace Tensorflow
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies)
{
throw new NotImplementedException();
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions);
_setup_function_captures(proto.ConcreteFunctionName, dependencies);
return (fn, setattr);
}

// TODO: remove this to a common class.


+ 14
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/nested_structure_coder.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Training.Saving.SavedModel
{
//public class nested_structure_coder
//{
// public static List<object> decode_proto(StructuredValue proto)
// {
// return proto s
// }
//}
}

+ 5
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -572,6 +572,11 @@ namespace Tensorflow
return get_default_graph().building_function;
}

public static SafeTensorHandle get_resource_handle_data(Tensor graph_op)
{
throw new NotImplementedException();
}

public static void dismantle_graph(Graph graph)
{


+ 13
- 0
src/TensorFlowNET.Keras/Saving/KerasMetaData.cs View File

@@ -8,9 +8,14 @@ namespace Tensorflow.Keras.Saving
{
public class KerasMetaData
{
[JsonProperty("name")]
public string Name { get; set; }
[JsonProperty("class_name")]
public string ClassName { get; set; }
[JsonProperty("trainable")]
public bool Trainable { get; set; }
[JsonProperty("dtype")]
public TF_DataType DType { get; set; } = TF_DataType.DtInvalid;
[JsonProperty("is_graph_network")]
public bool IsGraphNetwork { get; set; }
[JsonProperty("shared_object_id")]
@@ -20,5 +25,13 @@ namespace Tensorflow.Keras.Saving
public JObject Config { get; set; }
[JsonProperty("build_input_shape")]
public TensorShapeConfig BuildInputShape { get; set; }
[JsonProperty("batch_input_shape")]
public TensorShapeConfig BatchInputShape { get; set; }
[JsonProperty("activity_regularizer")]
public IRegularizer ActivityRegularizer { get; set; }
[JsonProperty("input_spec")]
public JToken InputSpec { get; set; }
[JsonProperty("stateful")]
public bool? Stateful { get; set; }
}
}

+ 18
- 4
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving
{
public class KerasObjectLoader
{
private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects;
internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects;
private SavedMetadata _metadata;
private SavedObjectGraph _proto;
private Dictionary<int, string> _node_paths = new Dictionary<int, string>();
@@ -311,6 +311,10 @@ namespace Tensorflow.Keras.Saving
{
(obj, setter) = _revive_custom_object(identifier, metadata);
}
if(obj is null)
{
throw new ValueError($"Cannot revive {metadata.Name} from the config or customized object.");
}
Debug.Assert(obj is Layer);
_maybe_add_serialized_attributes(obj as Layer, metadata);
return (obj, setter);
@@ -349,8 +353,14 @@ namespace Tensorflow.Keras.Saving

private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER)
{
return RevivedLayer.init_from_metadata(metadata);
}
else
{
throw new NotImplementedException();
}
}

Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id)
@@ -403,9 +413,13 @@ namespace Tensorflow.Keras.Saving

var obj = generic_utils.deserialize_keras_object(class_name, config);

if(obj is null)
{
return null;
}
obj.Name = metadata.Name;
// TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec`

var built = _try_build_layer(obj, node_id, metadata.BuildInputShape);
if (!built)


+ 62
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs View File

@@ -0,0 +1,62 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.Text.RegularExpressions;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel
{
internal static class ReviveUtils
{
public static T recursively_deserialize_keras_object<T>(JToken config)
{
throw new NotImplementedException();
if(config is JObject jobject)
{
if (jobject.ContainsKey("class_name"))
{
}
}
}

public static void _revive_setter(object layer, object name, object value)
{
Debug.Assert(name is string);
Debug.Assert(layer is Layer);
if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string))
{
if (value is Trackable trackable)
{
(layer as Layer)._track_trackable(trackable, name as string);
}
(layer as Layer).SerializedAttributes[name] = JToken.FromObject(value);
}
else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success)
{
Debug.Assert(value is Trackable);
functional._track_trackable(value as Trackable, name as string);
}
else
{
var properties = layer.GetType().GetProperties();
foreach (var p in properties)
{
if ((string)name == p.Name)
{
if(p.GetValue(layer) is not null)
{
return;
}
p.SetValue(layer, value);
return;
}
}
}
}
}
}

+ 37
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedConfig.cs View File

@@ -0,0 +1,37 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving.SavedModel
{
[JsonConverter(typeof(CustomizedRevivedConfigJsonConverter))]
public class RevivedConfig: IKerasConfig
{
public JObject Config { get; set; }
}

public class CustomizedRevivedConfigJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(RevivedConfig);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
((RevivedConfig)value).Config.WriteTo(writer);
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var config = (JObject)serializer.Deserialize(reader, typeof(JObject));
return new RevivedConfig() { Config = config };
}
}
}

+ 73
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs View File

@@ -0,0 +1,73 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class RevivedLayer: Layer
{
public static (RevivedLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
{
LayerArgs args = new LayerArgs()
{
Name = metadata.Name,
Trainable = metadata.Trainable
};
if(metadata.DType != TF_DataType.DtInvalid)
{
args.DType = metadata.DType;
}
if(metadata.BatchInputShape is not null)
{
args.BatchInputShape = metadata.BatchInputShape;
}

RevivedLayer revived_obj = new RevivedLayer(args);

// TODO(Rinne): implement `expects_training_arg`.
var config = metadata.Config;
if (generic_utils.validate_config(config))
{
revived_obj._config = new RevivedConfig() { Config = config };
}
if(metadata.InputSpec is not null)
{
throw new NotImplementedException();
}
if(metadata.ActivityRegularizer is not null)
{
throw new NotImplementedException();
}
// TODO(Rinne): `_is_feature_layer`
if(metadata.Stateful is not null)
{
revived_obj.stateful = metadata.Stateful.Value;
}

return (revived_obj, ReviveUtils._revive_setter);
}

private RevivedConfig _config = null;

public RevivedLayer(LayerArgs args): base(args)
{

}

public override string ToString()
{
return $"Customized keras layer: {Name}.";
}

public override IKerasConfig get_config()
{
return _config;
}
}
}

+ 9
- 0
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -23,6 +23,7 @@ using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Security.AccessControl;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
@@ -60,6 +61,10 @@ namespace Tensorflow.Keras.Utils
public static Layer deserialize_keras_object(string class_name, JToken config)
{
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
if(argType is null)
{
return null;
}
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
@@ -72,6 +77,10 @@ namespace Tensorflow.Keras.Utils
public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
if (layer is null)
{
return null;
}
Debug.Assert(layer is Layer);
return layer as Layer;
}


+ 9
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -1,10 +1,12 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

@@ -56,4 +58,11 @@ public class SequentialModelLoad

model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
}

[TestMethod]
public void Temp()
{
var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func");
model.summary();
}
}

Loading…
Cancel
Save