From 6d1b45993d32dc16616e822e44a26cdd80963354 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 9 Jan 2021 10:37:42 -0600 Subject: [PATCH] Fix resize_nearest_neighbor_grad. --- .../Functions/ConcreteFunction.cs | 58 +++++++++++----- .../Functions/EagerDefinedFunction.cs | 8 ++- .../Functions/TapeGradientFunctions.cs | 53 ++++++++++++--- .../Gradients/array_grad.cs | 11 ++-- .../Gradients/image_grad.cs | 2 +- .../Graphs/AutoGraphAttribute.cs | 52 ++++++++++++--- .../Operations/array_ops.cs | 23 ++++++- .../Operations/gen_image_ops.cs | 38 +++++++---- .../Operations/gen_random_ops.cs | 10 +++ .../Tensorflow.Binding.csproj | 1 + src/TensorFlowNET.Core/Tensors/Tensors.cs | 3 + src/TensorFlowNET.Core/tensorflow.cs | 2 +- src/TensorFlowNET.Keras/BackendImpl.cs | 3 +- .../Engine/Interfaces/ITensorFlowOpLayer.cs | 12 ++++ src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 10 +-- .../Engine/TensorFlowOpLayer.cs | 66 ------------------- src/TensorFlowNET.Keras/KerasInterface.cs | 5 +- .../Tensorflow.Keras.csproj | 3 +- .../Utils/base_layer_utils.cs | 17 ++++- 19 files changed, 241 insertions(+), 136 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs delete mode 100644 src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index b1d932b6..10bac1fc 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -14,6 +14,7 @@ namespace Tensorflow.Functions { IntPtr _handle; FuncGraph func_graph; + public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; public string Name @@ -127,30 +128,53 @@ namespace Tensorflow.Functions func_graph.Exit(); } - public Tensors Invoke(Tensors inputs) + public Tensors FilteredCall(Tensors inputs) { - var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); - var (forward_function, args_with_tangents) = forward_backward.Forward(); - Tensors flat_outputs = null; - if (tf.Context.executing_eagerly()) - flat_outputs = forward_function.Call(args_with_tangents); - forward_backward.Record(flat_outputs); - return flat_outputs; + return CallFlat(inputs, CapturedInputs); } + /// + /// Executes the wrapped function. + /// + /// + /// + /// public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) { - var new_args = new List(); - new_args.AddRange(args); - new_args.AddRange(captured_inputs); - args = new_args.ToArray(); + var executing_eagerly = tf.Context.executing_eagerly(); + var default_graph = ops.get_default_graph(); + var tensor_inputs = new Tensors(); + foreach (var (i, arg) in enumerate(args)) + { + tensor_inputs.Add(arg); + // If we're graph building, shape inference is on. + if (!executing_eagerly) + { + } + } + tensor_inputs.AddRange(captured_inputs); + + args = tensor_inputs.ToArray(); - var attrs = new object[] + var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; + // No tape is watching; skip to running the function. + if (possible_gradient_type == 0 && executing_eagerly) { - "executor_type", "", - "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() - }; - return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); + var attrs = new object[] + { + "executor_type", "", + "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() + }; + return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); + } + + var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); + var (forward_function, args_with_tangents) = forward_backward.Forward(); + Tensors flat_outputs = null; + if (executing_eagerly) + flat_outputs = forward_function.Call(args_with_tangents); + forward_backward.Record(flat_outputs); + return flat_outputs; } ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index f615f6a4..bfb8aa71 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -31,11 +31,17 @@ namespace Tensorflow.Functions public Tensors Call(Tensors args) { + var attrs = new object[] + { + "executor_type", "", + "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() + }; + var results = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, _func_graph.FuncName, args, - null, + attrs, _num_outputs); return results; diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 13c57e86..78f8e794 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -49,24 +49,61 @@ namespace Tensorflow.Functions getBackwardFunction: () => backward_function); } + /// + /// Create a backward function given `outputs` from the forward function. + /// + /// + /// + /// + /// (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) { - BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => + var capture_mapping = new Dictionary(); + foreach(var (i, output) in enumerate(outputs)) + capture_mapping[forward_graph.Outputs[i].Id] = output; + + var remapped_captures = new Tensors(); + foreach(var capture in backward.CapturedInputs) + { + if (capture_mapping.ContainsKey(capture.Id)) + remapped_captures.Add(capture_mapping[capture.Id]); + } + + var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; + var recorded_outputs = new Tensors(); + var relevant_outputs = outputs; + var trainable_recorded_outputs = 0; + var skip_positions = new List(); + foreach (var (output_index, output) in enumerate(relevant_outputs)) + { + if (trainable_recorded_outputs < backward_function_inputs) + recorded_outputs.Add(output); + if (gradients_util.IsTrainable(output)) + trainable_recorded_outputs += 1; + else + skip_positions.Add(output_index); + } + + BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => { - var processed_args = new List(); + var processed_args = new Tensors(); var input_index = 0; - foreach (var (output_index, arg) in enumerate(output_grads)) + foreach (var (output_index, arg) in enumerate(args)) { - if (arg is null) + if (skip_positions.Contains(output_index)) + continue; + if (arg == null) throw new NotImplementedException(""); - processed_args.add(arg); + processed_args.Add(arg); input_index += 1; + if (input_index >= backward_function_inputs) + break; } tf.Logger.Debug($"Invoke backward function: {backward.Name}"); - return backward.CallFlat(processed_args.ToArray(), outputs); + return backward.CallFlat(processed_args, remapped_captures); }; - return (_backward_function_wrapper, outputs); + return (_backward_function_wrapper, recorded_outputs); } protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List, int) @@ -103,7 +140,7 @@ namespace Tensorflow.Functions } backwards_graph.Exit(); - var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; + var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; var backward_function_attr = new Dictionary(); backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; gradients_wrt_outputs.append(backwards_graph.internal_captures); diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 0fe61bd2..db18d25c 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -228,13 +228,14 @@ namespace Tensorflow.Gradients var grad = grads[0]; var x = op.inputs[0]; var a = op.inputs[1]; - var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); - var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); + var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) }); + var begin = constant_op.constant(new[] { 0, 0 }); + var pad_before = array_ops.slice(a, begin, size); // Make it a 1-D tensor. - var begin = array_ops.reshape(pad_before, new[] { -1 }); - var sizes = array_ops.shape(x); - var x_grad = array_ops.slice(grad, begin, sizes); + begin = array_ops.reshape(pad_before, new[] { -1 }); + size = array_ops.shape(x); + var x_grad = array_ops.slice(grad, begin, size); if (len(op.inputs) == 3) return new Tensor[] { x_grad, null, null }; diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs index 87a67d64..ccc70fea 100644 --- a/src/TensorFlowNET.Core/Gradients/image_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Gradients var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); Tensor image_shape = null; if (shape.is_fully_defined()) - throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); + image_shape = constant_op.constant(image.shape[1..3]); else image_shape = array_ops.shape(image)["1:3"]; diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index 010eb345..9ffc7ea0 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -8,6 +8,9 @@ using static Tensorflow.Binding; namespace Tensorflow.Graphs { + /// + /// func_graph.py func_graph_from_py_func + /// [AllowChangingInputArguments] public sealed class AutoGraphAttribute : OnMethodBoundaryAspect { @@ -18,15 +21,16 @@ namespace Tensorflow.Graphs public override void OnEntry(MethodExecutionArgs args) { - func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; + // TODO: func_name can be cache in FullName + Args + func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}"; if (functions.ContainsKey(func_name)) { function = functions[func_name]; if (args.Arguments[0] is Tensors tensor_inputs) - args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); + args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs)); else - args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); + args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray())); args.FlowBehavior = FlowBehavior.Return; return; } @@ -62,14 +66,27 @@ namespace Tensorflow.Graphs { if (args.ReturnValue is Tensors outputs) { - if (args.Arguments[0] is Tensors inputs) - function.ToGraph(inputs, outputs); + Tensors inputs = null; + outputs = mark_as_return(outputs); + if (args.Arguments[0] is Tensors inputs1) + inputs = inputs1; else - function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); + inputs = args.Arguments.Select(x => x as Tensor).ToArray(); + + inputs = inputs.Where(x => x.op.OpType == "Placeholder" + && x.op.name.StartsWith("inputs")).ToArray(); + + function.ToGraph(inputs, outputs); } - else - function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); - + else if (args.ReturnValue is Tensor output) + { + var inputs = args.Arguments.Select(x => x as Tensor) + .Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs")) + .ToArray(); + var outputs2 = array_ops.identity(output); + function.ToGraph(inputs, outputs2); + } + function.Exit(); // cache function. @@ -77,7 +94,7 @@ namespace Tensorflow.Graphs functions[func_name] = function; // run function - args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); + args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs)); } object ConvertReturnValue(Tensors tensors) @@ -87,5 +104,20 @@ namespace Tensorflow.Graphs else return tensors; } + + /// + /// Acts like identity but marks the `Tensor` as a return value. + /// + /// + /// + public Tensors mark_as_return(Tensors tensors) + { + if (tensors == null) + return null; + var result = new Tensors(); + foreach (var tensor in tensors) + result.Add(array_ops.identity(tensor)); + return result; + } } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 1de84664..34670070 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -925,7 +925,28 @@ namespace Tensorflow public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) => gen_array_ops.slice(input, begin, size, name: name); - public static Tensor stack(object values, int axis = 0, string name = "stack") + public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Slice", name, new + { + input, begin, size + }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Slice", name, + null, + input, begin, size).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T"), + "Index", op.get_attr("Index") + }; + tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); + }, + new Tensors(input, begin, size)); + + public static Tensor stack(object values, int axis = 0, string name = "stack") { if (axis == 0) // If the input is a constant list, it can be converted to a constant op diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs index 21045d75..87bc12ee 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs @@ -238,18 +238,32 @@ namespace Tensorflow "half_pixel_centers", half_pixel_centers).FirstOrDefault(), images); - public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tsize size, bool align_corners = false, + public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null) - { - var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new - { - grads, - size, - align_corners, - half_pixel_centers - }); - - return op.output; - } + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new + { + grads, + size, + align_corners, + half_pixel_centers + }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ResizeNearestNeighborGrad", name, + null, + grads, size, + "align_corners", align_corners, + "half_pixel_centers", half_pixel_centers).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T"), + "align_corners", op.get_attr("align_corners"), + "half_pixel_centers", op.get_attr("half_pixel_centers") + }; + tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs); + }, + new Tensors(grads, size)); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index dda9dfa0..8528f4c4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -126,6 +126,16 @@ namespace Tensorflow public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "RandomShuffle", name, + null, + value, seed, seed2); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("RandomShuffle", name: name, args: new { value, seed, seed2 }); diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 8d3645b2..d0d445a2 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -83,6 +83,7 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including: + diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 82f51f58..8e0315ef 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -57,6 +57,9 @@ namespace Tensorflow public void Add(Tensor tensor) => items.Add(tensor); + public void AddRange(Tensor[] tensors) + => items.AddRange(tensors); + IEnumerator IEnumerable.GetEnumerator() { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 60b22f71..f7349dc5 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -48,7 +48,7 @@ namespace Tensorflow public tensorflow() { Logger = new LoggerConfiguration() - .MinimumLevel.Error() + .MinimumLevel.Warning() .WriteTo.Console() .CreateLogger(); diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 4c25f70a..ee3eaead 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -16,6 +16,7 @@ using NumSharp; using System; +using System.Linq; using System.Collections.Generic; using Tensorflow.Functions; using Tensorflow.Graphs; @@ -197,7 +198,7 @@ namespace Tensorflow.Keras } if (outputs[0].op.type == "Placeholder" || outputs[0].op.type == "StridedSlice") - return exec_graph.external_captures[0].numpy(); + return exec_graph.external_captures.Last().numpy(); // Consolidate updates exec_graph.as_default(); diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs new file mode 100644 index 00000000..a1d3ecbf --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine +{ + public interface ITensorFlowOpLayer + { + Layer GetOpLayer(TensorFlowOpLayerArgs args); + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 3699f6af..8c395281 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -1,5 +1,4 @@ using NumSharp; -using ShellProgressBar; using System; using System.Collections.Generic; using System.Linq; @@ -88,15 +87,8 @@ namespace Tensorflow.Keras.Engine { stop_training = false; _train_counter.assign(0); - var options = new ProgressBarOptions - { - ProgressCharacter = '.', - ProgressBarOnBottom = true - }; - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { - using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options); // reset_metrics(); // callbacks.on_epoch_begin(epoch) // data_handler.catch_stop_iteration(); @@ -105,7 +97,7 @@ namespace Tensorflow.Keras.Engine // callbacks.on_train_batch_begin(step) var results = step_function(iterator); var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); - pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); + Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); } } } diff --git a/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs deleted file mode 100644 index d0bf36e6..00000000 --- a/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs +++ /dev/null @@ -1,66 +0,0 @@ -using NumSharp; -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Graphs; -using Tensorflow.Keras.ArgsDefinition; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras.Engine -{ - public class TensorFlowOpLayer : Layer - { - TensorFlowOpLayerArgs args; - Dictionary constants => args.Constants; - NodeDef node_def => args.NodeDef; - static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; - public string OpType => node_def.Op; - - public TensorFlowOpLayer(TensorFlowOpLayerArgs args) - : base(new LayerArgs - { - Name = TF_OP_LAYER_NAME_PREFIX + args.Name, - Trainable = args.Trainable, - DType = args.DType, - Autocast = false - }) - { - this.args = args; - built = true; - } - - protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) - { - if (tf.Context.executing_eagerly()) - return _defun_call(inputs); - return MakOp(inputs); - } - - [AutoGraph] - Tensors _defun_call(Tensors inputs) - => MakOp(inputs); - - Tensors MakOp(Tensors inputs) - { - foreach (var (index, constant) in enumerate(constants)) - { - var value = constant_op.constant(constant, name: node_def.Input[index]); - var new_inputs = inputs.ToList(); - new_inputs.Insert(index, value); - inputs = new Tensors(new_inputs.ToArray()); - } - - var graph = inputs.graph; - var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); - var op = graph._create_op_from_tf_operation(c_op); - op._control_flow_post_processing(); - - // Record the gradient because custom-made ops don't go through the - // code-gen'd eager call path - var op_type = op.node_def.Op; - - tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); - - return op.outputs; - } - } -} diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index 40519ac4..50f80b6d 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -1,4 +1,7 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Datasets; using Tensorflow.Keras.Engine; diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 14c5719d..fa3d32ba 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 @@ -47,7 +47,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac - diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index fe93e584..ecba3473 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -18,6 +18,7 @@ using NumSharp; using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -151,7 +152,7 @@ namespace Tensorflow.Keras.Utils // recursively CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); - Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs + var op_layer = GetLayer(new TensorFlowOpLayerArgs { NodeDef = op.node_def, Constants = constants, @@ -164,6 +165,20 @@ namespace Tensorflow.Keras.Utils } } + static Layer GetLayer(LayerArgs args) + { + Layer layer = default; + var assemble = Assembly.Load("TensorFlow.Keras.Layers"); + foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null)) + { + layer = (Layer)Activator.CreateInstance(type, new object[] { args }); + } + + if (layer == null) + throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}"); + return layer; + } + // recusive static bool uses_keras_history(Tensor op_input) {