@@ -14,6 +14,7 @@ namespace Tensorflow.Functions | |||
{ | |||
IntPtr _handle; | |||
FuncGraph func_graph; | |||
public Tensor[] Inputs => func_graph.Inputs; | |||
public Tensor[] CapturedInputs => func_graph.external_captures; | |||
public string Name | |||
@@ -127,30 +128,53 @@ namespace Tensorflow.Functions | |||
func_graph.Exit(); | |||
} | |||
public Tensors Invoke(Tensors inputs) | |||
public Tensors FilteredCall(Tensors inputs) | |||
{ | |||
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | |||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
Tensors flat_outputs = null; | |||
if (tf.Context.executing_eagerly()) | |||
flat_outputs = forward_function.Call(args_with_tangents); | |||
forward_backward.Record(flat_outputs); | |||
return flat_outputs; | |||
return CallFlat(inputs, CapturedInputs); | |||
} | |||
/// <summary> | |||
/// Executes the wrapped function. | |||
/// </summary> | |||
/// <param name="args"></param> | |||
/// <param name="captured_inputs"></param> | |||
/// <returns></returns> | |||
public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) | |||
{ | |||
var new_args = new List<Tensor>(); | |||
new_args.AddRange(args); | |||
new_args.AddRange(captured_inputs); | |||
args = new_args.ToArray(); | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
var default_graph = ops.get_default_graph(); | |||
var tensor_inputs = new Tensors(); | |||
foreach (var (i, arg) in enumerate(args)) | |||
{ | |||
tensor_inputs.Add(arg); | |||
// If we're graph building, shape inference is on. | |||
if (!executing_eagerly) | |||
{ | |||
} | |||
} | |||
tensor_inputs.AddRange(captured_inputs); | |||
args = tensor_inputs.ToArray(); | |||
var attrs = new object[] | |||
var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||
// No tape is watching; skip to running the function. | |||
if (possible_gradient_type == 0 && executing_eagerly) | |||
{ | |||
"executor_type", "", | |||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
}; | |||
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
var attrs = new object[] | |||
{ | |||
"executor_type", "", | |||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
}; | |||
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
} | |||
var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
Tensors flat_outputs = null; | |||
if (executing_eagerly) | |||
flat_outputs = forward_function.Call(args_with_tangents); | |||
forward_backward.Record(flat_outputs); | |||
return flat_outputs; | |||
} | |||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
@@ -31,11 +31,17 @@ namespace Tensorflow.Functions | |||
public Tensors Call(Tensors args) | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"executor_type", "", | |||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
}; | |||
var results = tf.Runner.TFE_Execute(tf.Context, | |||
tf.Context.DeviceName, | |||
_func_graph.FuncName, | |||
args, | |||
null, | |||
attrs, | |||
_num_outputs); | |||
return results; | |||
@@ -49,24 +49,61 @@ namespace Tensorflow.Functions | |||
getBackwardFunction: () => backward_function); | |||
} | |||
/// <summary> | |||
/// Create a backward function given `outputs` from the forward function. | |||
/// </summary> | |||
/// <param name="forward_graph"></param> | |||
/// <param name="backward"></param> | |||
/// <param name="outputs"></param> | |||
/// <returns></returns> | |||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
{ | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
var capture_mapping = new Dictionary<long, Tensor>(); | |||
foreach(var (i, output) in enumerate(outputs)) | |||
capture_mapping[forward_graph.Outputs[i].Id] = output; | |||
var remapped_captures = new Tensors(); | |||
foreach(var capture in backward.CapturedInputs) | |||
{ | |||
if (capture_mapping.ContainsKey(capture.Id)) | |||
remapped_captures.Add(capture_mapping[capture.Id]); | |||
} | |||
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||
var recorded_outputs = new Tensors(); | |||
var relevant_outputs = outputs; | |||
var trainable_recorded_outputs = 0; | |||
var skip_positions = new List<int>(); | |||
foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||
{ | |||
if (trainable_recorded_outputs < backward_function_inputs) | |||
recorded_outputs.Add(output); | |||
if (gradients_util.IsTrainable(output)) | |||
trainable_recorded_outputs += 1; | |||
else | |||
skip_positions.Add(output_index); | |||
} | |||
BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => | |||
{ | |||
var processed_args = new List<Tensor>(); | |||
var processed_args = new Tensors(); | |||
var input_index = 0; | |||
foreach (var (output_index, arg) in enumerate(output_grads)) | |||
foreach (var (output_index, arg) in enumerate(args)) | |||
{ | |||
if (arg is null) | |||
if (skip_positions.Contains(output_index)) | |||
continue; | |||
if (arg == null) | |||
throw new NotImplementedException(""); | |||
processed_args.add(arg); | |||
processed_args.Add(arg); | |||
input_index += 1; | |||
if (input_index >= backward_function_inputs) | |||
break; | |||
} | |||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
return backward.CallFlat(processed_args.ToArray(), outputs); | |||
return backward.CallFlat(processed_args, remapped_captures); | |||
}; | |||
return (_backward_function_wrapper, outputs); | |||
return (_backward_function_wrapper, recorded_outputs); | |||
} | |||
protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
@@ -103,7 +140,7 @@ namespace Tensorflow.Functions | |||
} | |||
backwards_graph.Exit(); | |||
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
var backward_function_attr = new Dictionary<string, string>(); | |||
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||
@@ -228,13 +228,14 @@ namespace Tensorflow.Gradients | |||
var grad = grads[0]; | |||
var x = op.inputs[0]; | |||
var a = op.inputs[1]; | |||
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); | |||
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); | |||
var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) }); | |||
var begin = constant_op.constant(new[] { 0, 0 }); | |||
var pad_before = array_ops.slice(a, begin, size); | |||
// Make it a 1-D tensor. | |||
var begin = array_ops.reshape(pad_before, new[] { -1 }); | |||
var sizes = array_ops.shape(x); | |||
var x_grad = array_ops.slice(grad, begin, sizes); | |||
begin = array_ops.reshape(pad_before, new[] { -1 }); | |||
size = array_ops.shape(x); | |||
var x_grad = array_ops.slice(grad, begin, size); | |||
if (len(op.inputs) == 3) | |||
return new Tensor[] { x_grad, null, null }; | |||
@@ -30,7 +30,7 @@ namespace Tensorflow.Gradients | |||
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | |||
Tensor image_shape = null; | |||
if (shape.is_fully_defined()) | |||
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); | |||
image_shape = constant_op.constant(image.shape[1..3]); | |||
else | |||
image_shape = array_ops.shape(image)["1:3"]; | |||
@@ -8,6 +8,9 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs | |||
{ | |||
/// <summary> | |||
/// func_graph.py func_graph_from_py_func | |||
/// </summary> | |||
[AllowChangingInputArguments] | |||
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | |||
{ | |||
@@ -18,15 +21,16 @@ namespace Tensorflow.Graphs | |||
public override void OnEntry(MethodExecutionArgs args) | |||
{ | |||
func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; | |||
// TODO: func_name can be cache in FullName + Args | |||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}"; | |||
if (functions.ContainsKey(func_name)) | |||
{ | |||
function = functions[func_name]; | |||
if (args.Arguments[0] is Tensors tensor_inputs) | |||
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||
args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs)); | |||
else | |||
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||
args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray())); | |||
args.FlowBehavior = FlowBehavior.Return; | |||
return; | |||
} | |||
@@ -62,14 +66,27 @@ namespace Tensorflow.Graphs | |||
{ | |||
if (args.ReturnValue is Tensors outputs) | |||
{ | |||
if (args.Arguments[0] is Tensors inputs) | |||
function.ToGraph(inputs, outputs); | |||
Tensors inputs = null; | |||
outputs = mark_as_return(outputs); | |||
if (args.Arguments[0] is Tensors inputs1) | |||
inputs = inputs1; | |||
else | |||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||
inputs = args.Arguments.Select(x => x as Tensor).ToArray(); | |||
inputs = inputs.Where(x => x.op.OpType == "Placeholder" | |||
&& x.op.name.StartsWith("inputs")).ToArray(); | |||
function.ToGraph(inputs, outputs); | |||
} | |||
else | |||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||
else if (args.ReturnValue is Tensor output) | |||
{ | |||
var inputs = args.Arguments.Select(x => x as Tensor) | |||
.Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs")) | |||
.ToArray(); | |||
var outputs2 = array_ops.identity(output); | |||
function.ToGraph(inputs, outputs2); | |||
} | |||
function.Exit(); | |||
// cache function. | |||
@@ -77,7 +94,7 @@ namespace Tensorflow.Graphs | |||
functions[func_name] = function; | |||
// run function | |||
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||
args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs)); | |||
} | |||
object ConvertReturnValue(Tensors tensors) | |||
@@ -87,5 +104,20 @@ namespace Tensorflow.Graphs | |||
else | |||
return tensors; | |||
} | |||
/// <summary> | |||
/// Acts like identity but marks the `Tensor` as a return value. | |||
/// </summary> | |||
/// <param name="tensors"></param> | |||
/// <returns></returns> | |||
public Tensors mark_as_return(Tensors tensors) | |||
{ | |||
if (tensors == null) | |||
return null; | |||
var result = new Tensors(); | |||
foreach (var tensor in tensors) | |||
result.Add(array_ops.identity(tensor)); | |||
return result; | |||
} | |||
} | |||
} |
@@ -925,7 +925,28 @@ namespace Tensorflow | |||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
public static Tensor stack(object values, int axis = 0, string name = "stack") | |||
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) | |||
=> tf.Context.RunInAutoMode2( | |||
() => tf.OpDefLib._apply_op_helper("Slice", name, new | |||
{ | |||
input, begin, size | |||
}).output, | |||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Slice", name, | |||
null, | |||
input, begin, size).FirstOrDefault(), | |||
(op) => | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"T", op.get_attr<TF_DataType>("T"), | |||
"Index", op.get_attr<int>("Index") | |||
}; | |||
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); | |||
}, | |||
new Tensors(input, begin, size)); | |||
public static Tensor stack(object values, int axis = 0, string name = "stack") | |||
{ | |||
if (axis == 0) | |||
// If the input is a constant list, it can be converted to a constant op | |||
@@ -238,18 +238,32 @@ namespace Tensorflow | |||
"half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||
images); | |||
public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false, | |||
public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, | |||
bool half_pixel_centers = false, string name = null) | |||
{ | |||
var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new | |||
{ | |||
grads, | |||
size, | |||
align_corners, | |||
half_pixel_centers | |||
}); | |||
return op.output; | |||
} | |||
=> tf.Context.RunInAutoMode2( | |||
() => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new | |||
{ | |||
grads, | |||
size, | |||
align_corners, | |||
half_pixel_centers | |||
}).output, | |||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ResizeNearestNeighborGrad", name, | |||
null, | |||
grads, size, | |||
"align_corners", align_corners, | |||
"half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||
(op) => | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"T", op.get_attr<TF_DataType>("T"), | |||
"align_corners", op.get_attr<bool>("align_corners"), | |||
"half_pixel_centers", op.get_attr<bool>("half_pixel_centers") | |||
}; | |||
tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs); | |||
}, | |||
new Tensors(grads, size)); | |||
} | |||
} |
@@ -126,6 +126,16 @@ namespace Tensorflow | |||
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | |||
string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"RandomShuffle", name, | |||
null, | |||
value, seed, seed2); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("RandomShuffle", | |||
name: name, | |||
args: new { value, seed, seed2 }); | |||
@@ -83,6 +83,7 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including: | |||
<ItemGroup> | |||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | |||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" /> | |||
<PackageReference Include="NumSharp.Lite" Version="0.1.10" /> | |||
<PackageReference Include="Protobuf.Text" Version="0.4.0" /> | |||
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" /> | |||
@@ -57,6 +57,9 @@ namespace Tensorflow | |||
public void Add(Tensor tensor) | |||
=> items.Add(tensor); | |||
public void AddRange(Tensor[] tensors) | |||
=> items.AddRange(tensors); | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
throw new NotImplementedException(); | |||
@@ -48,7 +48,7 @@ namespace Tensorflow | |||
public tensorflow() | |||
{ | |||
Logger = new LoggerConfiguration() | |||
.MinimumLevel.Error() | |||
.MinimumLevel.Warning() | |||
.WriteTo.Console() | |||
.CreateLogger(); | |||
@@ -16,6 +16,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Linq; | |||
using System.Collections.Generic; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Graphs; | |||
@@ -197,7 +198,7 @@ namespace Tensorflow.Keras | |||
} | |||
if (outputs[0].op.type == "Placeholder" | |||
|| outputs[0].op.type == "StridedSlice") | |||
return exec_graph.external_captures[0].numpy(); | |||
return exec_graph.external_captures.Last().numpy(); | |||
// Consolidate updates | |||
exec_graph.as_default(); | |||
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public interface ITensorFlowOpLayer | |||
{ | |||
Layer GetOpLayer(TensorFlowOpLayerArgs args); | |||
} | |||
} |
@@ -1,5 +1,4 @@ | |||
using NumSharp; | |||
using ShellProgressBar; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
@@ -88,15 +87,8 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
stop_training = false; | |||
_train_counter.assign(0); | |||
var options = new ProgressBarOptions | |||
{ | |||
ProgressCharacter = '.', | |||
ProgressBarOnBottom = true | |||
}; | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options); | |||
// reset_metrics(); | |||
// callbacks.on_epoch_begin(epoch) | |||
// data_handler.catch_stop_iteration(); | |||
@@ -105,7 +97,7 @@ namespace Tensorflow.Keras.Engine | |||
// callbacks.on_train_batch_begin(step) | |||
var results = step_function(iterator); | |||
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | |||
pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); | |||
Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); | |||
} | |||
} | |||
} | |||
@@ -1,66 +0,0 @@ | |||
using NumSharp; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class TensorFlowOpLayer : Layer | |||
{ | |||
TensorFlowOpLayerArgs args; | |||
Dictionary<int, NDArray> constants => args.Constants; | |||
NodeDef node_def => args.NodeDef; | |||
static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; | |||
public string OpType => node_def.Op; | |||
public TensorFlowOpLayer(TensorFlowOpLayerArgs args) | |||
: base(new LayerArgs | |||
{ | |||
Name = TF_OP_LAYER_NAME_PREFIX + args.Name, | |||
Trainable = args.Trainable, | |||
DType = args.DType, | |||
Autocast = false | |||
}) | |||
{ | |||
this.args = args; | |||
built = true; | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
return _defun_call(inputs); | |||
return MakOp(inputs); | |||
} | |||
[AutoGraph] | |||
Tensors _defun_call(Tensors inputs) | |||
=> MakOp(inputs); | |||
Tensors MakOp(Tensors inputs) | |||
{ | |||
foreach (var (index, constant) in enumerate(constants)) | |||
{ | |||
var value = constant_op.constant(constant, name: node_def.Input[index]); | |||
var new_inputs = inputs.ToList(); | |||
new_inputs.Insert(index, value); | |||
inputs = new Tensors(new_inputs.ToArray()); | |||
} | |||
var graph = inputs.graph; | |||
var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); | |||
var op = graph._create_op_from_tf_operation(c_op); | |||
op._control_flow_post_processing(); | |||
// Record the gradient because custom-made ops don't go through the | |||
// code-gen'd eager call path | |||
var op_type = op.node_def.Op; | |||
tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | |||
return op.outputs; | |||
} | |||
} | |||
} |
@@ -1,4 +1,7 @@ | |||
using System.Collections.Generic; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Reflection; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Datasets; | |||
using Tensorflow.Keras.Engine; | |||
@@ -1,4 +1,4 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
@@ -47,7 +47,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | |||
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> | |||
<PackageReference Include="SharpZipLib" Version="1.3.1" /> | |||
<PackageReference Include="ShellProgressBar" Version="5.0.0" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
@@ -18,6 +18,7 @@ using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Reflection; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
@@ -151,7 +152,7 @@ namespace Tensorflow.Keras.Utils | |||
// recursively | |||
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | |||
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||
var op_layer = GetLayer<ITensorFlowOpLayer>(new TensorFlowOpLayerArgs | |||
{ | |||
NodeDef = op.node_def, | |||
Constants = constants, | |||
@@ -164,6 +165,20 @@ namespace Tensorflow.Keras.Utils | |||
} | |||
} | |||
static Layer GetLayer<T>(LayerArgs args) | |||
{ | |||
Layer layer = default; | |||
var assemble = Assembly.Load("TensorFlow.Keras.Layers"); | |||
foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null)) | |||
{ | |||
layer = (Layer)Activator.CreateInstance(type, new object[] { args }); | |||
} | |||
if (layer == null) | |||
throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}"); | |||
return layer; | |||
} | |||
// recusive | |||
static bool uses_keras_history(Tensor op_input) | |||
{ | |||