@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -45,6 +46,23 @@ namespace Tensorflow | |||
{ | |||
return as_text(bytes_or_text, encoding); | |||
} | |||
public ByteString as_bytes(ByteString bytes, Encoding encoding = null) | |||
{ | |||
return bytes; | |||
} | |||
public ByteString as_bytes(byte[] bytes, Encoding encoding = null) | |||
{ | |||
return ByteString.CopyFrom(bytes); | |||
} | |||
public ByteString as_bytes(string text, Encoding encoding = null) | |||
{ | |||
if(encoding is null) | |||
{ | |||
encoding = Encoding.UTF8; | |||
} | |||
return ByteString.CopyFrom(encoding.GetBytes(text)); | |||
} | |||
} | |||
public bool executing_eagerly() | |||
@@ -54,6 +54,6 @@ namespace Tensorflow | |||
Dictionary<string, Tensor> input_map = null, | |||
string[] return_elements = null, | |||
string name = null, | |||
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); | |||
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); | |||
} | |||
} |
@@ -156,6 +156,12 @@ namespace Tensorflow.Contexts | |||
return has_graph_arg; | |||
} | |||
public bool has_function(string name) | |||
{ | |||
ensure_initialized(); | |||
return c_api.TFE_ContextHasFunction(_handle, name); | |||
} | |||
public void restore_mode() | |||
{ | |||
context_switches.Pop(); | |||
@@ -0,0 +1,288 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Security.Cryptography; | |||
using System.Text; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
namespace Tensorflow.Framework | |||
{ | |||
public class function_def_lib | |||
{ | |||
// TODO(Rinne): process signatures and structured outputs. | |||
public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, | |||
object? structured_outputs, List<TensorShapeProto> input_shapes = null) | |||
{ | |||
var func_graph = new FuncGraph(fdef.Signature.Name); | |||
if(input_shapes is null) | |||
{ | |||
if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) | |||
{ | |||
var raw_input_shapes = input_shapes_attr.List.Shape; | |||
input_shapes = new List<TensorShapeProto>(); | |||
foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) | |||
{ | |||
if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
{ | |||
input_shapes.Add(null); | |||
} | |||
else | |||
{ | |||
input_shapes.Add(input_shape); | |||
} | |||
} | |||
} | |||
} | |||
var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); | |||
func_graph.as_default(); | |||
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
// TODO(Rinne): func_graph.ControlOutputs | |||
_set_handle_data(func_graph, fdef); | |||
foreach(var node in graph_def.Node) | |||
{ | |||
if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) | |||
{ | |||
var op = func_graph.get_operation_by_name(node.Name); | |||
foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) | |||
{ | |||
op.outputs[output_index].shape = new Shape(shape); | |||
} | |||
} | |||
} | |||
Dictionary<long, string> output_names = new(); | |||
foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) | |||
{ | |||
output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||
} | |||
// TODO(Rinne): func_graph._output_names = output_names | |||
func_graph.Exit(); | |||
return func_graph; | |||
} | |||
public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes) | |||
{ | |||
var graph_def = new GraphDef() | |||
{ | |||
Versions = new VersionDef() | |||
{ | |||
Producer = versions.GRAPH_DEF_VERSION, | |||
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
} | |||
}; | |||
var default_graph = ops.get_default_graph(); | |||
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) | |||
{ | |||
throw new ValueError($"Length of `input_shapes` must match the number " + | |||
$"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + | |||
$"{fdef.Signature.InputArg.Count} `input_arg`s."); | |||
} | |||
foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) | |||
{ | |||
NodeDef node_def = new(); | |||
node_def.Name = arg_def.Name; | |||
node_def.Op = "Placeholder"; | |||
node_def.Attr["dtype"] = new AttrValue() | |||
{ | |||
Type = arg_def.Type | |||
}; | |||
if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) | |||
{ | |||
var input_shape = input_shapes[i]; | |||
// skip the condition that input_shape is not `TensorShapeProto`. | |||
AttrValue shape = new AttrValue() | |||
{ | |||
Shape = new TensorShapeProto() | |||
}; | |||
shape.Shape = new TensorShapeProto(input_shape); | |||
node_def.Attr["shape"] = shape; | |||
} | |||
if (!fdef.ArgAttr.ContainsKey((uint)i)) | |||
{ | |||
fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); | |||
} | |||
var arg_attrs = fdef.ArgAttr[(uint)i].Attr; | |||
foreach(var k in arg_attrs.Keys) | |||
{ | |||
if(k == "_output_shapes") | |||
{ | |||
if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); | |||
} | |||
else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) | |||
{ | |||
node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); | |||
} | |||
} | |||
else if (k.StartsWith("_")) | |||
{ | |||
if (!node_def.Attr.ContainsKey(k)) | |||
{ | |||
node_def.Attr[k] = new AttrValue(); | |||
} | |||
node_def.Attr[k] = new AttrValue(arg_attrs[k]); | |||
} | |||
} | |||
graph_def.Node.Add(node_def); | |||
} | |||
graph_def.Node.AddRange(fdef.NodeDef); | |||
Dictionary<string, string> nested_to_flat_tensor_name = new(); | |||
foreach(var arg_def in fdef.Signature.InputArg) | |||
{ | |||
nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; | |||
string control_name = "^" + arg_def.Name; | |||
nested_to_flat_tensor_name[control_name] = control_name; | |||
} | |||
foreach(var node_def in fdef.NodeDef) | |||
{ | |||
var graph = default_graph; | |||
// TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. | |||
while(graph.OuterGraph is not null) | |||
{ | |||
graph = graph.OuterGraph; | |||
} | |||
var op_def = default_graph.GetOpDef(node_def.Op); | |||
foreach(var attr in op_def.Attr) | |||
{ | |||
if(attr.Type == "func") | |||
{ | |||
var fname = node_def.Attr[attr.Name].Func.Name; | |||
if (!is_function(fname)) | |||
{ | |||
throw new ValueError($"Function {fname} was not found. Please make sure " + | |||
$"the FunctionDef `fdef` is correct."); | |||
} | |||
} | |||
else if(attr.Type == "list(func)") | |||
{ | |||
foreach(var fn in node_def.Attr[attr.Name].List.Func) | |||
{ | |||
var fname = fn.Name; | |||
if (!is_function(fname)) | |||
{ | |||
throw new ValueError($"Function {fname} was not found. Please make " + | |||
$"sure the FunctionDef `fdef` is correct."); | |||
} | |||
} | |||
} | |||
} | |||
int flattened_index = 0; | |||
foreach(var arg_def in op_def.OutputArg) | |||
{ | |||
var num_args = _get_num_args(arg_def, node_def); | |||
for(int i = 0; i < num_args; i++) | |||
{ | |||
var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; | |||
var flat_name = $"{node_def.Name}:{flattened_index}"; | |||
nested_to_flat_tensor_name[nested_name] = flat_name; | |||
flattened_index++; | |||
} | |||
} | |||
string control_name = "^" + node_def.Name; | |||
nested_to_flat_tensor_name[control_name] = control_name; | |||
} | |||
foreach(var node_def in graph_def.Node) | |||
{ | |||
for(int i = 0; i < node_def.Input.Count; i++) | |||
{ | |||
node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; | |||
} | |||
} | |||
return (graph_def, nested_to_flat_tensor_name); | |||
} | |||
private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) | |||
{ | |||
foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) | |||
{ | |||
if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
{ | |||
tensor.shape = Shape.Scalar; | |||
var shape_and_type = arg_def.HandleData[0]; | |||
var handle_data = new HandleData(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
{ | |||
Shape = shape_and_type.Shape, | |||
Dtype = shape_and_type.Dtype | |||
}); | |||
resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); | |||
} | |||
} | |||
} | |||
private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) | |||
{ | |||
if (!string.IsNullOrEmpty(arg_def.NumberAttr)) | |||
{ | |||
return node_def.Attr[arg_def.NumberAttr].I; | |||
} | |||
else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) | |||
{ | |||
return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; | |||
} | |||
else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) | |||
{ | |||
return 1; | |||
} | |||
else | |||
{ | |||
throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + | |||
$"FunctionDef `fdef` is correct."); | |||
} | |||
} | |||
public static bool is_function(string fname) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return tf.Context.has_function(fname); | |||
} | |||
else | |||
{ | |||
var graph = ops.get_default_graph(); | |||
while(graph is not null) | |||
{ | |||
if (graph.IsFunction(fname)) | |||
{ | |||
return true; | |||
} | |||
if(graph.OuterGraph is not null) | |||
{ | |||
graph = graph.OuterGraph; | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
} | |||
throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + | |||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
} |
@@ -17,6 +17,7 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.OpDef.Types; | |||
@@ -25,9 +26,14 @@ namespace Tensorflow | |||
{ | |||
public class importer | |||
{ | |||
public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) | |||
{ | |||
return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); | |||
} | |||
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
Dictionary<string, Tensor> input_map = null, | |||
string[] return_elements = null, | |||
bool validate_colocation_constraints = true, | |||
string name = null, | |||
OpList producer_op_list = null) | |||
{ | |||
@@ -60,7 +66,7 @@ namespace Tensorflow | |||
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
var status = new Status(); | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); | |||
// need to create a class ImportGraphDefWithResults with IDisposal | |||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||
status.Check(true); | |||
@@ -73,6 +79,42 @@ namespace Tensorflow | |||
return _GatherReturnElements(return_elements, graph, results); | |||
} | |||
//private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null, | |||
// bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) | |||
//{ | |||
// graph_def = _ProcessGraphDefParam(graph_def); | |||
// input_map = _ProcessInputMapParam(input_map); | |||
// return_elements = _ProcessReturnElementsParam(return_elements); | |||
// if(producer_op_list is not null) | |||
// { | |||
// _RemoveDefaultAttrs(producer_op_list, graph_def); | |||
// } | |||
// var graph = ops.get_default_graph(); | |||
// string prefix = null; | |||
// tf_with(ops.name_scope(name, "import", input_map.Values), scope => | |||
// { | |||
// if (scope is not null) | |||
// { | |||
// Debug.Assert(scope.scope_name.EndsWith("/")); | |||
// prefix = scope.scope_name[scope.scope_name.Length - 1].ToString(); | |||
// } | |||
// else | |||
// { | |||
// prefix = ""; | |||
// } | |||
// input_map = _ConvertInputMapValues(name, input_map); | |||
// }); | |||
// var scope_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
// var options = scope_options.Options; | |||
// _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints); | |||
//} | |||
private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | |||
Graph graph, | |||
TF_ImportGraphDefResults results) | |||
@@ -113,15 +155,29 @@ namespace Tensorflow | |||
public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | |||
string prefix, | |||
Dictionary<string, Tensor> input_map, | |||
string[] return_elements) | |||
string[] return_elements, | |||
bool validate_colocation_constraints) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); | |||
foreach (var input in input_map) | |||
{ | |||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||
var input_src = tf.compat.as_str(input.Key); | |||
var input_dst = input.Value; | |||
if (input_src.StartsWith("^")) | |||
{ | |||
var src_name = tf.compat.as_str(input_src.Substring(1)); | |||
var dst_op = input_dst._as_tf_output().oper; | |||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); | |||
} | |||
else | |||
{ | |||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||
src_name = tf.compat.as_str(src_name); | |||
var dst_output = input_dst._as_tf_output(); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); | |||
} | |||
} | |||
if (return_elements == null) | |||
@@ -132,15 +188,16 @@ namespace Tensorflow | |||
if (name.Contains(":")) | |||
{ | |||
var (op_name, index) = _ParseTensorName(name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
op_name = tf.compat.as_str(op_name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); | |||
} | |||
else | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); | |||
} | |||
} | |||
// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); | |||
} | |||
private static (string, int) _ParseTensorName(string tensor_name) | |||
@@ -173,6 +230,14 @@ namespace Tensorflow | |||
return graph_def; | |||
} | |||
private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) | |||
{ | |||
var old_graph_def = graph_def; | |||
graph_def = new GraphDef(old_graph_def); | |||
return graph_def; | |||
} | |||
private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | |||
{ | |||
foreach (var attr_def in op_def.Attr) | |||
@@ -240,6 +305,35 @@ namespace Tensorflow | |||
} | |||
} | |||
private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) | |||
{ | |||
var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); | |||
foreach (var node in graph_def.Node) | |||
{ | |||
// Remove any default attr values that aren't in op_def. | |||
if (producer_op_dict.ContainsKey(node.Op)) | |||
{ | |||
var op_def = op_def_registry.GetOpDef(node.Op); | |||
if(op_def is null) | |||
{ | |||
continue; | |||
} | |||
var producer_op_def = producer_op_dict[node.Op]; | |||
foreach (var key in node.Attr.Keys) | |||
{ | |||
if (_FindAttrInOpDef(key, op_def) is null) | |||
{ | |||
var attr_def = _FindAttrInOpDef(key, producer_op_def); | |||
if (attr_def != null && attr_def.DefaultValue != null && | |||
node.Attr[key] == attr_def.DefaultValue) | |||
node.Attr[key].ClearValue(); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | |||
{ | |||
return op_def.Attr.FirstOrDefault(x => x.Name == name); | |||
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Framework | |||
{ | |||
public class versions | |||
{ | |||
public static int GRAPH_DEF_VERSION = 1286; | |||
public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; | |||
} | |||
} |
@@ -13,6 +13,7 @@ namespace Tensorflow.Functions | |||
/// </summary> | |||
public class ConcreteFunction: Trackable | |||
{ | |||
protected IEnumerable<Tensor> _captured_inputs; | |||
internal FuncGraph func_graph; | |||
internal ForwardBackwardCall forward_backward; | |||
public Tensor[] Inputs => func_graph.Inputs; | |||
@@ -29,11 +30,13 @@ namespace Tensorflow.Functions | |||
public ConcreteFunction(string name) | |||
{ | |||
func_graph = new FuncGraph(name); | |||
_captured_inputs = func_graph.external_captures; | |||
} | |||
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
{ | |||
func_graph = graph; | |||
_captured_inputs = func_graph.external_captures; | |||
ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
} | |||
@@ -53,6 +56,7 @@ namespace Tensorflow.Functions | |||
new[] { output }, | |||
null); | |||
func_graph.Exit(); | |||
_captured_inputs = func_graph.external_captures; | |||
} | |||
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
@@ -73,6 +77,7 @@ namespace Tensorflow.Functions | |||
new[] { output.variant_tensor }, | |||
null); | |||
func_graph.Exit(); | |||
_captured_inputs = func_graph.external_captures; | |||
} | |||
/*public ConcreteFunction(Func<Tensors, Tensors> func, | |||
@@ -174,6 +179,11 @@ namespace Tensorflow.Functions | |||
// TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||
} | |||
public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||
{ | |||
_captured_inputs = captures; | |||
} | |||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
{ | |||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Train; | |||
namespace Tensorflow | |||
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Functions | |||
{ | |||
public interface IGenericFunction | |||
{ | |||
object[] Apply(params object[] args); | |||
ConcreteFunction get_concrete_function(params object[] args); | |||
} | |||
} |
@@ -0,0 +1,88 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Functions | |||
{ | |||
public static class function_saved_model_utils | |||
{ | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="concrete_function"></param> | |||
/// <param name="inputs">a list tensors or other objects (such as variables) which | |||
/// contain tensors that were originally captured by the function</param> | |||
public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs) | |||
{ | |||
var bound_inputs = inputs?.Select(obj => | |||
{ | |||
if(obj is Tensor tensor) | |||
{ | |||
return get_tensor_from_node(tensor); | |||
} | |||
else if(obj is IVariableV1 variable) | |||
{ | |||
return get_tensor_from_node(variable); | |||
} | |||
else | |||
{ | |||
throw new TypeError("Encountered an type error, please submit an issue to " + | |||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
}); | |||
var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); | |||
List<Tensor> captured_inputs_list = new(); | |||
// TODO(Rinne): concrete_function.set_variables(bound_variables) | |||
if (bound_inputs is not null) | |||
{ | |||
foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) | |||
{ | |||
if(hasattr(bound_input, "__tf_experimental_restore_capture__")) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
captured_inputs_list.Add(bound_input); | |||
concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||
if(internal_capture.dtype == dtypes.resource) | |||
{ | |||
// skip the check of variable. | |||
handle_data_util.copy_handle_data(bound_input, internal_capture); | |||
} | |||
concrete_function.func_graph.capture(bound_input); | |||
} | |||
} | |||
} | |||
if(captured_inputs_list.Any(inp => inp is null)) | |||
{ | |||
// TODO(Rinne): add warnings. | |||
} | |||
concrete_function.SetExternalCaptures(captured_inputs_list); | |||
} | |||
public static Tensor get_tensor_from_node(Tensor node) | |||
{ | |||
return node; | |||
} | |||
public static Tensor get_tensor_from_node(IVariableV1 node) | |||
{ | |||
if (resource_variable_ops.is_resource_variable(node)) | |||
{ | |||
return node.Handle; | |||
} | |||
else | |||
{ | |||
throw new TypeError("Encountered an type error, please submit an issue to " + | |||
"https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Gradients | |||
{ | |||
public class custom_gradient | |||
{ | |||
public static string generate_name() | |||
{ | |||
return $"CustomGradient-{ops.uid()}"; | |||
} | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using MethodBoundaryAspect.Fody.Attributes; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Functions; | |||
@@ -21,8 +22,9 @@ namespace Tensorflow.Graphs | |||
public override void OnEntry(MethodExecutionArgs args) | |||
{ | |||
File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); | |||
// TODO: func_name can be cache in FullName + Args | |||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||
if (functions.ContainsKey(func_name)) | |||
{ | |||
@@ -56,6 +56,11 @@ public class FuncGraph : Graph, IDisposable | |||
_handle = handle; | |||
} | |||
public void replace_capture(Tensor tensor, Tensor placeholder) | |||
{ | |||
_captures[tensor.Id] = (tensor, placeholder); | |||
} | |||
public void ToGraph(Operation[] opers, | |||
Tensor[] inputs, Tensor[] outputs, | |||
string[] output_names) | |||
@@ -146,6 +146,12 @@ namespace Tensorflow | |||
return ops.set_default_graph(this); | |||
} | |||
public bool IsFunction(string name) | |||
{ | |||
// TODO(Rinne): deal with `_functions`. | |||
throw new NotImplementedException(); | |||
} | |||
private Tensor _as_graph_element(object obj) | |||
{ | |||
if (obj is RefVariable var) | |||
@@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions | |||
_handle = c_api.TF_NewImportGraphDefOptions(); | |||
} | |||
public SafeImportGraphDefOptionsHandle Options => _handle; | |||
public void AddReturnOutput(string name, int index) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
@@ -185,6 +185,9 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); | |||
/// <summary> | |||
/// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
/// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||
@@ -246,7 +249,7 @@ namespace Tensorflow | |||
/// <param name="ops">TF_ImportGraphDefOptions*</param> | |||
/// <param name="uniquify_prefix">unsigned char</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||
public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); | |||
/// <summary> | |||
/// Fetches the return operations requested via | |||
@@ -0,0 +1,28 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Operations | |||
{ | |||
public static class handle_data_util | |||
{ | |||
public static void copy_handle_data(Tensor source_t, Tensor target_t) | |||
{ | |||
if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||
{ | |||
SafeTensorHandle handle_data; | |||
if(source_t is EagerTensor) | |||
{ | |||
handle_data = source_t.Handle; | |||
} | |||
else | |||
{ | |||
handle_data = ops.get_resource_handle_data(source_t); | |||
} | |||
throw new NotImplementedException(); | |||
//if(handle_data is not null && handle_data.) | |||
} | |||
} | |||
} | |||
} |
@@ -126,7 +126,7 @@ namespace Tensorflow | |||
/// <param name="handle"></param> | |||
/// <param name="handle_data"></param> | |||
/// <param name="graph_mode"></param> | |||
private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
{ | |||
if (!graph_mode) | |||
return; | |||
@@ -5,6 +5,7 @@ | |||
#pragma warning disable 1591, 0612, 3021 | |||
#region Designer generated code | |||
using Tensorflow.Framework.Models; | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
@@ -2589,9 +2590,17 @@ namespace Tensorflow { | |||
} | |||
} | |||
#region Nested types | |||
/// <summary>Container for nested types declared in the FunctionSpec message type.</summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
//public static FunctionSpec from_function_and_signature(string csharp_function, IEnumerable<TensorSpec> input_signature, bool is_pure = false, object jit_compile = null) | |||
//{ | |||
// // TODO(Rinne): _validate_signature(input_signature) | |||
// // TODO(Rinne): _validate_python_function(python_function, input_signature) | |||
//} | |||
#region Nested types | |||
/// <summary>Container for nested types declared in the FunctionSpec message type.</summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static partial class Types { | |||
/// <summary> | |||
/// Whether the function should be compiled by XLA. | |||
@@ -1,14 +1,24 @@ | |||
using System; | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
using System.Text.RegularExpressions; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Training.Saving.SavedModel | |||
{ | |||
public static class function_deserialization | |||
{ | |||
private static string _INFERENCE_PREFIX = "__inference_"; | |||
private static string _FUNCTION_WRAPPER_NAME_REGEX = $@"^{_INFERENCE_PREFIX}(.*)_\d+$"; | |||
/// <summary> | |||
/// Creates a `Function` from a `SavedFunction`. | |||
/// </summary> | |||
@@ -22,6 +32,338 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
return null; | |||
} | |||
public static Dictionary<string, ConcreteFunction> load_function_def_library(FunctionDefLibrary library, | |||
SavedObjectGraph saved_object_graph = null, string load_shared_name_suffix = null, object? wrapper_function = null) | |||
{ | |||
var library_function_names = library.Function.Select(x => x.Signature.Name).Distinct(); | |||
Dictionary<string, ConcreteFunction> functions = new(); | |||
Dictionary<string, ConcreteFunction> renamed_functions = new(); | |||
Graph graph; | |||
if (ops.executing_eagerly_outside_functions()) | |||
{ | |||
graph = new Graph(); | |||
} | |||
else | |||
{ | |||
graph = ops.get_default_graph(); | |||
} | |||
if(load_shared_name_suffix is null) | |||
{ | |||
load_shared_name_suffix = $"_load_{ops.uid()}"; | |||
} | |||
Dictionary<ByteString, string> library_gradient_names = new(); | |||
Dictionary<ByteString, string> new_gradient_op_types = new(); | |||
Dictionary<string, string> gradients_to_register = new(); | |||
foreach (var gdef in library.RegisteredGradients) | |||
{ | |||
if(gdef.RegisteredOpType is not null) | |||
{ | |||
var new_op_type = custom_gradient.generate_name(); | |||
var old_op_type = tf.compat.as_bytes(gdef.RegisteredOpType); | |||
library_gradient_names[old_op_type] = gdef.GradientFunc; | |||
new_gradient_op_types[old_op_type] = new_op_type; | |||
gradients_to_register[gdef.GradientFunc] = new_op_type; | |||
} | |||
} | |||
Dictionary<string, IEnumerable<string>> function_deps = new(); | |||
foreach(var fdef in library.Function) | |||
{ | |||
function_deps[fdef.Signature.Name] = _list_function_deps(fdef, library_function_names, library_gradient_names); | |||
} | |||
Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||
int aa = 0; | |||
var temp = _sort_function_defs(library, function_deps); | |||
foreach (var fdef in temp) | |||
{ | |||
aa++; | |||
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
{ | |||
// TODO(Rinne): implement it. | |||
//var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
//throw new NotImplementedException(); | |||
} | |||
graph.as_default(); | |||
var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); | |||
graph.Exit(); | |||
_restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); | |||
foreach(var dep in function_deps[orig_name]) | |||
{ | |||
functions[dep].AddTograph(func_graph); | |||
} | |||
if (fdef.Attr.ContainsKey("_input_shapes")) | |||
{ | |||
fdef.Attr.Remove("_input_shapes"); | |||
} | |||
var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); | |||
if(wrapper_function is not null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
func.AddTograph(graph); | |||
functions[orig_name] = func; | |||
renamed_functions[func.Name] = func; | |||
if(func_graph.get_operations().Any(op => op.op.type == "TRTEngineOp")) | |||
{ | |||
func.AddTograph(ops.get_default_graph()); | |||
} | |||
if (gradients_to_register.ContainsKey(orig_name)) | |||
{ | |||
var gradient_op_type = gradients_to_register[orig_name]; | |||
loaded_gradients[gradient_op_type] = func; | |||
// TODO(Rinne): deal with gradient registry. | |||
//new RegisteredGradient() { RegisteredOpType = gradient_op_type }. | |||
} | |||
} | |||
return functions; | |||
} | |||
public static void fix_node_def(NodeDef node_def, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix) | |||
{ | |||
if (functions.ContainsKey(node_def.Op)) | |||
{ | |||
node_def.Op = functions[node_def.Op].Name; | |||
} | |||
foreach(var attr_value in node_def.Attr.Values) | |||
{ | |||
if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) | |||
{ | |||
attr_value.Func.Name = functions[attr_value.Func.Name].Name; | |||
} | |||
else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
foreach(var fn in attr_value.List.Func) | |||
{ | |||
fn.Name = functions[fn.Name].Name; | |||
} | |||
} | |||
} | |||
if(node_def.Op == "HashTableV2") | |||
{ | |||
if(!node_def.Attr.ContainsKey("use_node_name_sharing") || !node_def.Attr["use_node_name_sharing"].B) | |||
{ | |||
node_def.Attr["use_node_name_sharing"].B = true; | |||
shared_name_suffix += $"_{ops.uid()}"; | |||
} | |||
} | |||
var op_def = op_def_registry.GetOpDef(node_def.Op); | |||
if(op_def is not null) | |||
{ | |||
var attr = op_def.Attr.Where(x => x.Name == "shared_name").FirstOrDefault(); | |||
if(attr is not null) | |||
{ | |||
ByteString shared_name = null; | |||
if(node_def.Attr.ContainsKey("shared_name") && node_def.Attr["shared_name"].S is not null) | |||
{ | |||
shared_name = node_def.Attr["shared_name"].S; | |||
} | |||
else if(attr.DefaultValue.S is not null) | |||
{ | |||
shared_name = tf.compat.as_bytes(attr.DefaultValue.S); | |||
} | |||
if(shared_name is null) | |||
{ | |||
shared_name = tf.compat.as_bytes(node_def.Name); | |||
} | |||
node_def.Attr["shared_name"].S = ByteString.CopyFrom(shared_name.Concat(tf.compat.as_bytes(node_def.Name)).ToArray()); | |||
} | |||
} | |||
} | |||
private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) | |||
{ | |||
foreach(var op in func_graph.get_operations()) | |||
{ | |||
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
{ | |||
var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; | |||
// TODO(Rinne): deal with `op._gradient_function`. | |||
} | |||
string gradient_op_type = null; | |||
try | |||
{ | |||
gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
} | |||
catch(Exception e) | |||
{ | |||
continue; | |||
} | |||
if (loaded_gradients.ContainsKey(gradient_op_type)) | |||
{ | |||
var grad_fn = loaded_gradients[gradient_op_type]; | |||
grad_fn.NumPositionArgs = op.op.inputs.Length; | |||
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||
} | |||
} | |||
} | |||
private static string _fix_fdef_in_place(FunctionDef fdef, IDictionary<string, ConcreteFunction> functions, string shared_name_suffix, | |||
IDictionary<ByteString, string> new_gradient_op_types) | |||
{ | |||
var orig_name = fdef.Signature.Name; | |||
bool contains_unsaved_custom_gradients = false; | |||
foreach(var node_def in fdef.NodeDef) | |||
{ | |||
fix_node_def(node_def, functions, shared_name_suffix); | |||
var op_type = _get_gradient_op_type(node_def); | |||
if(op_type is not null) | |||
{ | |||
if (new_gradient_op_types.ContainsKey(op_type)) | |||
{ | |||
node_def.Attr["_gradient_op_type"].S = tf.compat.as_bytes(new_gradient_op_types[op_type]); | |||
} | |||
else | |||
{ | |||
contains_unsaved_custom_gradients = true; | |||
} | |||
} | |||
} | |||
if (contains_unsaved_custom_gradients) | |||
{ | |||
// TODO(Rinne): log warnings. | |||
} | |||
fdef.Signature.Name = _clean_function_name(fdef.Signature.Name); | |||
return orig_name; | |||
} | |||
private static string _clean_function_name(string name) | |||
{ | |||
var match = Regex.Match(name, _FUNCTION_WRAPPER_NAME_REGEX); | |||
if(match.Success) | |||
{ | |||
return match.Groups[1].Value; | |||
} | |||
else | |||
{ | |||
return name; | |||
} | |||
} | |||
/// <summary> | |||
/// Return a topologic sort of FunctionDefs in a library. | |||
/// </summary> | |||
/// <param name="library"></param> | |||
/// <param name="function_deps"></param> | |||
private static IEnumerable<FunctionDef> _sort_function_defs(FunctionDefLibrary library, Dictionary<string, IEnumerable<string>> function_deps) | |||
{ | |||
Dictionary<string, IList<string>> edges = new(); | |||
Dictionary<string, int> in_count = new(); | |||
foreach(var item in function_deps) | |||
{ | |||
var fname = item.Key; | |||
var deps = item.Value; | |||
if(deps is null || deps.Count() == 0) | |||
{ | |||
in_count[fname] = 0; | |||
continue; | |||
} | |||
foreach(var dep in deps) | |||
{ | |||
edges.SetDefault(dep, new List<string>()).Add(fname); | |||
if (in_count.ContainsKey(fname)) | |||
{ | |||
in_count[fname]++; | |||
} | |||
else | |||
{ | |||
in_count[fname] = 1; | |||
} | |||
} | |||
} | |||
var ready = new Stack<string>(library.Function. | |||
Where(x => in_count[x.Signature.Name] == 0) | |||
.Select(x => x.Signature.Name).ToList()); | |||
List<string> output = new(); | |||
while(ready.Count > 0) | |||
{ | |||
var node = ready.Pop(); | |||
output.Add(node); | |||
if (!edges.ContainsKey(node)) | |||
{ | |||
continue; | |||
} | |||
foreach(var dest in edges[node]) | |||
{ | |||
in_count[dest] -= 1; | |||
if (in_count[dest] == 0) | |||
{ | |||
ready.Push(dest); | |||
} | |||
} | |||
} | |||
if(output.Count != library.Function.Count) | |||
{ | |||
var failed_to_resolve = in_count.Keys.Except(output); | |||
throw new ValueError($"There is a cyclic dependency between functions. " + | |||
$"Could not resolve ({string.Join(", ", failed_to_resolve)})."); | |||
} | |||
var reverse = library.Function.ToDictionary(x => x.Signature.Name, x => x); | |||
return output.Select(x => reverse[x]); | |||
} | |||
private static IEnumerable<string> _list_function_deps(FunctionDef fdef, IEnumerable<string> library_function_names, IDictionary<ByteString, string> library_gradient_names) | |||
{ | |||
HashSet<string> deps = new HashSet<string>(); | |||
foreach(var node_def in fdef.NodeDef) | |||
{ | |||
var grad_op_type = _get_gradient_op_type(node_def); | |||
if (library_function_names.Contains(node_def.Op)) | |||
{ | |||
deps.Add(node_def.Op); | |||
} | |||
else if(grad_op_type is not null && library_gradient_names.TryGetValue(grad_op_type, out var gradient_name)) | |||
{ | |||
deps.Add(gradient_name); | |||
} | |||
else | |||
{ | |||
foreach(var attr_value in node_def.Attr.Values) | |||
{ | |||
if(attr_value.ValueCase == AttrValue.ValueOneofCase.Func) | |||
{ | |||
deps.Add(attr_value.Func.Name); | |||
} | |||
else if(attr_value.ValueCase == AttrValue.ValueOneofCase.List) | |||
{ | |||
foreach(var fn in attr_value.List.Func) | |||
{ | |||
deps.Add(fn.Name); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
return deps.AsEnumerable(); | |||
} | |||
private static ByteString _get_gradient_op_type(NodeDef node_def) | |||
{ | |||
if(node_def.Attr.ContainsKey("_gradient_op_type") && node_def.Op != "StatefulPartitionedCall" && node_def.Op != "PartitionedCall") | |||
{ | |||
return node_def.Attr["_gradient_op_type"].S; | |||
} | |||
return null; | |||
} | |||
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | |||
IDictionary<string, ConcreteFunction> concrete_functions) | |||
{ | |||
@@ -30,6 +372,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | |||
var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | |||
// TODO(Rinne): set the functiona spec. | |||
concrete_function.AddTograph(); | |||
return concrete_function; | |||
} | |||
@@ -35,6 +35,8 @@ namespace Tensorflow | |||
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
private List<Trackable> _nodes; | |||
private Dictionary<int, Action<object, object, object>> _node_setters; | |||
private Dictionary<string, ConcreteFunction> _concrete_functions; | |||
private HashSet<string> _restored_concrete_functions; | |||
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | |||
CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | |||
{ | |||
@@ -44,6 +46,9 @@ namespace Tensorflow | |||
_proto = object_graph_proto; | |||
_export_dir = export_dir; | |||
// TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||
_concrete_functions = function_deserialization.load_function_def_library( | |||
meta_graph.GraphDef.Library, _proto); | |||
_restored_concrete_functions = new HashSet<string>(); | |||
_checkpoint_options = ckpt_options; | |||
_save_options = save_options; | |||
@@ -464,9 +469,17 @@ namespace Tensorflow | |||
} | |||
} | |||
private void _setup_function_captures() | |||
private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes) | |||
{ | |||
// TODO: implement it with concrete functions. | |||
if (_restored_concrete_functions.Contains(concrete_function_name)) | |||
{ | |||
return; | |||
} | |||
_restored_concrete_functions.Add(concrete_function_name); | |||
var concrete_function = _concrete_functions[concrete_function_name]; | |||
var proto = _proto.ConcreteFunctions[concrete_function_name]; | |||
var inputs = proto.BoundInputs.Select(x => nodes[x]); | |||
function_saved_model_utils.restore_captures(concrete_function, inputs); | |||
} | |||
private void _setup_remaining_functions() | |||
@@ -625,7 +638,7 @@ namespace Tensorflow | |||
var fn = function_deserialization.recreate_function(proto, null); | |||
foreach (var name in proto.ConcreteFunctions) | |||
{ | |||
_setup_function_captures(); | |||
_setup_function_captures(name, dependencies); | |||
} | |||
return (fn, setattr); | |||
} | |||
@@ -633,8 +646,9 @@ namespace Tensorflow | |||
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
{ | |||
throw new NotImplementedException(); | |||
//var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||
_setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||
return (fn, setattr); | |||
} | |||
// TODO: remove this to a common class. | |||
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Training.Saving.SavedModel | |||
{ | |||
//public class nested_structure_coder | |||
//{ | |||
// public static List<object> decode_proto(StructuredValue proto) | |||
// { | |||
// return proto s | |||
// } | |||
//} | |||
} |
@@ -572,6 +572,11 @@ namespace Tensorflow | |||
return get_default_graph().building_function; | |||
} | |||
public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public static void dismantle_graph(Graph graph) | |||
{ | |||
@@ -8,9 +8,14 @@ namespace Tensorflow.Keras.Saving | |||
{ | |||
public class KerasMetaData | |||
{ | |||
[JsonProperty("name")] | |||
public string Name { get; set; } | |||
[JsonProperty("class_name")] | |||
public string ClassName { get; set; } | |||
[JsonProperty("trainable")] | |||
public bool Trainable { get; set; } | |||
[JsonProperty("dtype")] | |||
public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; | |||
[JsonProperty("is_graph_network")] | |||
public bool IsGraphNetwork { get; set; } | |||
[JsonProperty("shared_object_id")] | |||
@@ -20,5 +25,13 @@ namespace Tensorflow.Keras.Saving | |||
public JObject Config { get; set; } | |||
[JsonProperty("build_input_shape")] | |||
public TensorShapeConfig BuildInputShape { get; set; } | |||
[JsonProperty("batch_input_shape")] | |||
public TensorShapeConfig BatchInputShape { get; set; } | |||
[JsonProperty("activity_regularizer")] | |||
public IRegularizer ActivityRegularizer { get; set; } | |||
[JsonProperty("input_spec")] | |||
public JToken InputSpec { get; set; } | |||
[JsonProperty("stateful")] | |||
public bool? Stateful { get; set; } | |||
} | |||
} |
@@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving | |||
{ | |||
public class KerasObjectLoader | |||
{ | |||
private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
internal static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
private SavedMetadata _metadata; | |||
private SavedObjectGraph _proto; | |||
private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
@@ -311,6 +311,10 @@ namespace Tensorflow.Keras.Saving | |||
{ | |||
(obj, setter) = _revive_custom_object(identifier, metadata); | |||
} | |||
if(obj is null) | |||
{ | |||
throw new ValueError($"Cannot revive {metadata.Name} from the config or customized object."); | |||
} | |||
Debug.Assert(obj is Layer); | |||
_maybe_add_serialized_attributes(obj as Layer, metadata); | |||
return (obj, setter); | |||
@@ -349,8 +353,14 @@ namespace Tensorflow.Keras.Saving | |||
private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||
{ | |||
// TODO(Rinne): implement it. | |||
throw new NotImplementedException(); | |||
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
{ | |||
return RevivedLayer.init_from_metadata(metadata); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | |||
@@ -403,9 +413,13 @@ namespace Tensorflow.Keras.Saving | |||
var obj = generic_utils.deserialize_keras_object(class_name, config); | |||
if(obj is null) | |||
{ | |||
return null; | |||
} | |||
obj.Name = metadata.Name; | |||
// TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | |||
var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||
if (!built) | |||
@@ -0,0 +1,62 @@ | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Text; | |||
using System.Text.RegularExpressions; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Train; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
internal static class ReviveUtils | |||
{ | |||
public static T recursively_deserialize_keras_object<T>(JToken config) | |||
{ | |||
throw new NotImplementedException(); | |||
if(config is JObject jobject) | |||
{ | |||
if (jobject.ContainsKey("class_name")) | |||
{ | |||
} | |||
} | |||
} | |||
public static void _revive_setter(object layer, object name, object value) | |||
{ | |||
Debug.Assert(name is string); | |||
Debug.Assert(layer is Layer); | |||
if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||
{ | |||
if (value is Trackable trackable) | |||
{ | |||
(layer as Layer)._track_trackable(trackable, name as string); | |||
} | |||
(layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); | |||
} | |||
else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
{ | |||
Debug.Assert(value is Trackable); | |||
functional._track_trackable(value as Trackable, name as string); | |||
} | |||
else | |||
{ | |||
var properties = layer.GetType().GetProperties(); | |||
foreach (var p in properties) | |||
{ | |||
if ((string)name == p.Name) | |||
{ | |||
if(p.GetValue(layer) is not null) | |||
{ | |||
return; | |||
} | |||
p.SetValue(layer, value); | |||
return; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,37 @@ | |||
using Newtonsoft.Json; | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
[JsonConverter(typeof(CustomizedRevivedConfigJsonConverter))] | |||
public class RevivedConfig: IKerasConfig | |||
{ | |||
public JObject Config { get; set; } | |||
} | |||
public class CustomizedRevivedConfigJsonConverter : JsonConverter | |||
{ | |||
public override bool CanConvert(Type objectType) | |||
{ | |||
return objectType == typeof(RevivedConfig); | |||
} | |||
public override bool CanRead => true; | |||
public override bool CanWrite => true; | |||
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
{ | |||
((RevivedConfig)value).Config.WriteTo(writer); | |||
} | |||
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
{ | |||
var config = (JObject)serializer.Deserialize(reader, typeof(JObject)); | |||
return new RevivedConfig() { Config = config }; | |||
} | |||
} | |||
} |
@@ -0,0 +1,73 @@ | |||
using Newtonsoft.Json.Linq; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Keras.Saving.SavedModel; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
public class RevivedLayer: Layer | |||
{ | |||
public static (RevivedLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||
{ | |||
LayerArgs args = new LayerArgs() | |||
{ | |||
Name = metadata.Name, | |||
Trainable = metadata.Trainable | |||
}; | |||
if(metadata.DType != TF_DataType.DtInvalid) | |||
{ | |||
args.DType = metadata.DType; | |||
} | |||
if(metadata.BatchInputShape is not null) | |||
{ | |||
args.BatchInputShape = metadata.BatchInputShape; | |||
} | |||
RevivedLayer revived_obj = new RevivedLayer(args); | |||
// TODO(Rinne): implement `expects_training_arg`. | |||
var config = metadata.Config; | |||
if (generic_utils.validate_config(config)) | |||
{ | |||
revived_obj._config = new RevivedConfig() { Config = config }; | |||
} | |||
if(metadata.InputSpec is not null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
if(metadata.ActivityRegularizer is not null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
// TODO(Rinne): `_is_feature_layer` | |||
if(metadata.Stateful is not null) | |||
{ | |||
revived_obj.stateful = metadata.Stateful.Value; | |||
} | |||
return (revived_obj, ReviveUtils._revive_setter); | |||
} | |||
private RevivedConfig _config = null; | |||
public RevivedLayer(LayerArgs args): base(args) | |||
{ | |||
} | |||
public override string ToString() | |||
{ | |||
return $"Customized keras layer: {Name}."; | |||
} | |||
public override IKerasConfig get_config() | |||
{ | |||
return _config; | |||
} | |||
} | |||
} |
@@ -23,6 +23,7 @@ using System.Data; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using System.Reflection; | |||
using System.Security.AccessControl; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
@@ -60,6 +61,10 @@ namespace Tensorflow.Keras.Utils | |||
public static Layer deserialize_keras_object(string class_name, JToken config) | |||
{ | |||
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); | |||
if(argType is null) | |||
{ | |||
return null; | |||
} | |||
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) | |||
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); | |||
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); | |||
@@ -72,6 +77,10 @@ namespace Tensorflow.Keras.Utils | |||
public static Layer deserialize_keras_object(string class_name, LayerArgs args) | |||
{ | |||
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | |||
if (layer is null) | |||
{ | |||
return null; | |||
} | |||
Debug.Assert(layer is Layer); | |||
return layer as Layer; | |||
} | |||
@@ -1,10 +1,12 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.UnitTest.Helpers; | |||
using Tensorflow.NumPy; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
@@ -56,4 +58,11 @@ public class SequentialModelLoad | |||
model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||
} | |||
[TestMethod] | |||
public void Temp() | |||
{ | |||
var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); | |||
model.summary(); | |||
} | |||
} |