
Fix resize_nearest_neighbor_grad.

tags/keras_v0.3.0
Oceania2018 committed 4 years ago
commit 6d1b45993d
19 changed files with 241 additions and 136 deletions
  1. src/TensorFlowNET.Core/Functions/ConcreteFunction.cs (+41, -17)
  2. src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs (+7, -1)
  3. src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs (+45, -8)
  4. src/TensorFlowNET.Core/Gradients/array_grad.cs (+6, -5)
  5. src/TensorFlowNET.Core/Gradients/image_grad.cs (+1, -1)
  6. src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs (+42, -10)
  7. src/TensorFlowNET.Core/Operations/array_ops.cs (+22, -1)
  8. src/TensorFlowNET.Core/Operations/gen_image_ops.cs (+26, -12)
  9. src/TensorFlowNET.Core/Operations/gen_random_ops.cs (+10, -0)
  10. src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+1, -0)
  11. src/TensorFlowNET.Core/Tensors/Tensors.cs (+3, -0)
  12. src/TensorFlowNET.Core/tensorflow.cs (+1, -1)
  13. src/TensorFlowNET.Keras/BackendImpl.cs (+2, -1)
  14. src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs (+12, -0)
  15. src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+1, -9)
  16. src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs (+0, -66)
  17. src/TensorFlowNET.Keras/KerasInterface.cs (+4, -1)
  18. src/TensorFlowNET.Keras/Tensorflow.Keras.csproj (+1, -2)
  19. src/TensorFlowNET.Keras/Utils/base_layer_utils.cs (+16, -1)

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs (+41, -17)

@@ -14,6 +14,7 @@ namespace Tensorflow.Functions
{
IntPtr _handle;
FuncGraph func_graph;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;

public string Name
@@ -127,30 +128,53 @@ namespace Tensorflow.Functions
func_graph.Exit();
}

public Tensors Invoke(Tensors inputs)
public Tensors FilteredCall(Tensors inputs)
{
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null;
if (tf.Context.executing_eagerly())
flat_outputs = forward_function.Call(args_with_tangents);
forward_backward.Record(flat_outputs);
return flat_outputs;
return CallFlat(inputs, CapturedInputs);
}

/// <summary>
/// Executes the wrapped function.
/// </summary>
/// <param name="args"></param>
/// <param name="captured_inputs"></param>
/// <returns></returns>
public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
{
var new_args = new List<Tensor>();
new_args.AddRange(args);
new_args.AddRange(captured_inputs);
args = new_args.ToArray();
var executing_eagerly = tf.Context.executing_eagerly();
var default_graph = ops.get_default_graph();
var tensor_inputs = new Tensors();
foreach (var (i, arg) in enumerate(args))
{
tensor_inputs.Add(arg);
// If we're graph building, shape inference is on.
if (!executing_eagerly)
{
}
}
tensor_inputs.AddRange(captured_inputs);

args = tensor_inputs.ToArray();

var attrs = new object[]
var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0;
// No tape is watching; skip to running the function.
if (possible_gradient_type == 0 && executing_eagerly)
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
};
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
};
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
}

var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null;
if (executing_eagerly)
flat_outputs = forward_function.Call(args_with_tangents);
forward_backward.Record(flat_outputs);
return flat_outputs;
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
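A minimal usage sketch (not part of the commit) of the new FilteredCall / CallFlat pair; the helper class and method names are illustrative, and the call simply mirrors the body of FilteredCall above.

using Tensorflow;
using Tensorflow.Functions;

static class ConcreteFunctionUsage   // illustrative helper, not in the commit
{
    // FilteredCall(inputs) expands to CallFlat(inputs, CapturedInputs): the user
    // arguments are concatenated with the graph's captured inputs, then either
    // executed directly (eager mode, no tape watching) or routed through the
    // forward/backward pair so a gradient can be recorded.
    public static Tensor[] Run(ConcreteFunction fn, Tensors inputs)
        => fn.CallFlat(inputs, fn.CapturedInputs);
}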


src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs (+7, -1)

@@ -31,11 +31,17 @@ namespace Tensorflow.Functions

public Tensors Call(Tensors args)
{
var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
};

var results = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
_func_graph.FuncName,
args,
null,
attrs,
_num_outputs);

return results;


src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs (+45, -8)

@@ -49,24 +49,61 @@ namespace Tensorflow.Functions
getBackwardFunction: () => backward_function);
}

/// <summary>
/// Create a backward function given `outputs` from the forward function.
/// </summary>
/// <param name="forward_graph"></param>
/// <param name="backward"></param>
/// <param name="outputs"></param>
/// <returns></returns>
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
var capture_mapping = new Dictionary<long, Tensor>();
foreach(var (i, output) in enumerate(outputs))
capture_mapping[forward_graph.Outputs[i].Id] = output;

var remapped_captures = new Tensors();
foreach(var capture in backward.CapturedInputs)
{
if (capture_mapping.ContainsKey(capture.Id))
remapped_captures.Add(capture_mapping[capture.Id]);
}

var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
var recorded_outputs = new Tensors();
var relevant_outputs = outputs;
var trainable_recorded_outputs = 0;
var skip_positions = new List<int>();
foreach (var (output_index, output) in enumerate(relevant_outputs))
{
if (trainable_recorded_outputs < backward_function_inputs)
recorded_outputs.Add(output);
if (gradients_util.IsTrainable(output))
trainable_recorded_outputs += 1;
else
skip_positions.Add(output_index);
}

BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
{
var processed_args = new List<Tensor>();
var processed_args = new Tensors();
var input_index = 0;
foreach (var (output_index, arg) in enumerate(output_grads))
foreach (var (output_index, arg) in enumerate(args))
{
if (arg is null)
if (skip_positions.Contains(output_index))
continue;
if (arg == null)
throw new NotImplementedException("");
processed_args.add(arg);
processed_args.Add(arg);
input_index += 1;
if (input_index >= backward_function_inputs)
break;
}
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
return backward.CallFlat(processed_args.ToArray(), outputs);
return backward.CallFlat(processed_args, remapped_captures);
};

return (_backward_function_wrapper, outputs);
return (_backward_function_wrapper, recorded_outputs);
}

protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
@@ -103,7 +140,7 @@ namespace Tensorflow.Functions
}
backwards_graph.Exit();

var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
gradients_wrt_outputs.append(backwards_graph.internal_captures);
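For readers following _wrap_backward_function, the plain-C# sketch below (not library code) isolates the filtering the wrapper performs on incoming gradients: positions recorded as non-trainable (skip_positions) are dropped, and at most backward_function_inputs values are forwarded to backward.CallFlat.

using System.Collections.Generic;

static class BackwardFilterSketch   // illustrative only, not in the commit
{
    public static List<T> FilterGradients<T>(IList<T> incomingGrads,
        ISet<int> skipPositions, int backwardFunctionInputs)
    {
        var kept = new List<T>();
        for (int i = 0; i < incomingGrads.Count; i++)
        {
            if (skipPositions.Contains(i))
                continue;                       // gradient of a non-trainable output
            kept.Add(incomingGrads[i]);
            if (kept.Count >= backwardFunctionInputs)
                break;                          // backward function accepts no more inputs
        }
        return kept;
    }
}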


src/TensorFlowNET.Core/Gradients/array_grad.cs (+6, -5)

@@ -228,13 +228,14 @@ namespace Tensorflow.Gradients
var grad = grads[0];
var x = op.inputs[0];
var a = op.inputs[1];
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 });
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size);
var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) });
var begin = constant_op.constant(new[] { 0, 0 });
var pad_before = array_ops.slice(a, begin, size);

// Make it a 1-D tensor.
var begin = array_ops.reshape(pad_before, new[] { -1 });
var sizes = array_ops.shape(x);
var x_grad = array_ops.slice(grad, begin, sizes);
begin = array_ops.reshape(pad_before, new[] { -1 });
size = array_ops.shape(x);
var x_grad = array_ops.slice(grad, begin, size);

if (len(op.inputs) == 3)
return new Tensor[] { x_grad, null, null };
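The Pad gradient above slices the incoming gradient at the per-dimension leading pad with the original input's shape; the plain-C# sketch below (not library code) shows the 1-D case.

using System;

static class PadGradSketch   // illustrative only, not in the commit
{
    // grad has length padBefore + originalLength + padAfter; the gradient
    // w.r.t. the unpadded input is simply the middle slice.
    public static float[] PadGrad1D(float[] grad, int padBefore, int originalLength)
    {
        var dx = new float[originalLength];
        Array.Copy(grad, padBefore, dx, 0, originalLength);
        return dx;
    }
    // e.g. x of length 3 padded with (1, 2): grad has length 6 and
    // PadGrad1D(grad, 1, 3) returns grad[1..4].
}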


src/TensorFlowNET.Core/Gradients/image_grad.cs (+1, -1)

@@ -30,7 +30,7 @@ namespace Tensorflow.Gradients
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
Tensor image_shape = null;
if (shape.is_fully_defined())
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined");
image_shape = constant_op.constant(image.shape[1..3]);
else
image_shape = array_ops.shape(image)["1:3"];
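The fix above supplies the forward image's statically known height/width (image.shape[1..3]) instead of throwing. For intuition, the plain-C# sketch below (not the TensorFlow kernel) shows what ResizeNearestNeighborGrad computes for a single-channel image with align_corners = false and half_pixel_centers = false.

using System;

static class NearestNeighborGradSketch   // illustrative only, not in the commit
{
    // Each incoming gradient cell is accumulated into the source cell it was
    // copied from during the forward nearest-neighbor resize.
    public static float[,] Grad(float[,] grads, int srcH, int srcW)
    {
        int outH = grads.GetLength(0), outW = grads.GetLength(1);
        var dx = new float[srcH, srcW];
        for (int y = 0; y < outH; y++)
        for (int x = 0; x < outW; x++)
        {
            int sy = Math.Min((int)(y * (float)srcH / outH), srcH - 1);
            int sx = Math.Min((int)(x * (float)srcW / outW), srcW - 1);
            dx[sy, sx] += grads[y, x];
        }
        return dx;
    }
}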



src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs (+42, -10)

@@ -8,6 +8,9 @@ using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
/// <summary>
/// func_graph.py func_graph_from_py_func
/// </summary>
[AllowChangingInputArguments]
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
{
@@ -18,15 +21,16 @@ namespace Tensorflow.Graphs

public override void OnEntry(MethodExecutionArgs args)
{
func_name = $"{args.Method.Name}_{Guid.NewGuid()}";
// TODO: func_name can be cache in FullName + Args
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}";

if (functions.ContainsKey(func_name))
{
function = functions[func_name];
if (args.Arguments[0] is Tensors tensor_inputs)
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs));
args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs));
else
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray()));
args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray()));
args.FlowBehavior = FlowBehavior.Return;
return;
}
@@ -62,14 +66,27 @@ namespace Tensorflow.Graphs
{
if (args.ReturnValue is Tensors outputs)
{
if (args.Arguments[0] is Tensors inputs)
function.ToGraph(inputs, outputs);
Tensors inputs = null;
outputs = mark_as_return(outputs);
if (args.Arguments[0] is Tensors inputs1)
inputs = inputs1;
else
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs);
inputs = args.Arguments.Select(x => x as Tensor).ToArray();

inputs = inputs.Where(x => x.op.OpType == "Placeholder"
&& x.op.name.StartsWith("inputs")).ToArray();

function.ToGraph(inputs, outputs);
}
else
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor);
else if (args.ReturnValue is Tensor output)
{
var inputs = args.Arguments.Select(x => x as Tensor)
.Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs"))
.ToArray();
var outputs2 = array_ops.identity(output);
function.ToGraph(inputs, outputs2);
}

function.Exit();

// cache function.
@@ -77,7 +94,7 @@ namespace Tensorflow.Graphs
functions[func_name] = function;

// run function
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs));
args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
}

object ConvertReturnValue(Tensors tensors)
@@ -87,5 +104,20 @@ namespace Tensorflow.Graphs
else
return tensors;
}

/// <summary>
/// Acts like identity but marks the `Tensor` as a return value.
/// </summary>
/// <param name="tensors"></param>
/// <returns></returns>
public Tensors mark_as_return(Tensors tensors)
{
if (tensors == null)
return null;
var result = new Tensors();
foreach (var tensor in tensors)
result.Add(array_ops.identity(tensor));
return result;
}
}
}
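Usage-wise, [AutoGraph] is applied to a method that takes and returns tensors, as the _defun_call method of the TensorFlowOpLayer removed later in this commit does. The sketch below (class and op are illustrative) shows the pattern: the first call traces the body into a FuncGraph, later calls hit the cached ConcreteFunction via FilteredCall.

using Tensorflow;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

class ScaledAddLayer   // illustrative example, not in the commit
{
    [AutoGraph]
    public Tensors Call(Tensors inputs)
        => new Tensors(tf.add(inputs[0], inputs[1]));
}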

src/TensorFlowNET.Core/Operations/array_ops.cs (+22, -1)

@@ -925,7 +925,28 @@ namespace Tensorflow
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name);

public static Tensor stack(object values, int axis = 0, string name = "stack")
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Slice", name, new
{
input, begin, size
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Slice", name,
null,
input, begin, size).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T"),
"Index", op.get_attr<int>("Index")
};
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
},
new Tensors(input, begin, size));
public static Tensor stack(object values, int axis = 0, string name = "stack")
{
if (axis == 0)
// If the input is a constant list, it can be converted to a constant op
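A hypothetical call of the new tensor-begin/tensor-size overload (helper class and values are illustrative); because all three arguments are tensors, the same code path works eagerly and while tracing a FuncGraph.

using Tensorflow;
using static Tensorflow.Binding;

static class SliceExample   // illustrative only, not in the commit
{
    public static Tensor Run()
    {
        var input = tf.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });
        var begin = tf.constant(new[] { 0, 1 });
        var size  = tf.constant(new[] { 2, 2 });
        // Goes through the Tensor overload added above: rows 0..1, columns 1..2.
        return array_ops.slice(input, begin, size);   // [[2, 3], [5, 6]]
    }
}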


src/TensorFlowNET.Core/Operations/gen_image_ops.cs (+26, -12)

@@ -238,18 +238,32 @@ namespace Tensorflow
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
images);

public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
{
var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new
{
grads,
size,
align_corners,
half_pixel_centers
});

return op.output;
}
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new
{
grads,
size,
align_corners,
half_pixel_centers
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResizeNearestNeighborGrad", name,
null,
grads, size,
"align_corners", align_corners,
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T"),
"align_corners", op.get_attr<bool>("align_corners"),
"half_pixel_centers", op.get_attr<bool>("half_pixel_centers")
};
tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs);
},
new Tensors(grads, size));
}
}
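A hypothetical direct call of the regenerated wrapper (helper and shapes are illustrative): grads is the gradient flowing into a resized [batch, out_h, out_w, channels] tensor and size holds the original [in_h, in_w].

using Tensorflow;
using static Tensorflow.Binding;

static class ResizeGradExample   // illustrative only, not in the commit
{
    public static Tensor Run(Tensor grads)
    {
        var size = tf.constant(new[] { 2, 2 });   // original height/width (assumed)
        return gen_image_ops.resize_nearest_neighbor_grad(grads, size,
            align_corners: false, half_pixel_centers: false);
    }
}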

src/TensorFlowNET.Core/Operations/gen_random_ops.cs (+10, -0)

@@ -126,6 +126,16 @@ namespace Tensorflow
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomShuffle", name,
null,
value, seed, seed2);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("RandomShuffle",
name: name,
args: new { value, seed, seed2 });


src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+1, -0)

@@ -83,6 +83,7 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including:

<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
<PackageReference Include="NumSharp.Lite" Version="0.1.10" />
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />


src/TensorFlowNET.Core/Tensors/Tensors.cs (+3, -0)

@@ -57,6 +57,9 @@ namespace Tensorflow
public void Add(Tensor tensor)
=> items.Add(tensor);

public void AddRange(Tensor[] tensors)
=> items.AddRange(tensors);

IEnumerator IEnumerable.GetEnumerator()
{
throw new NotImplementedException();
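The new AddRange mirrors how ConcreteFunction.CallFlat (earlier in this commit) appends a function's captured inputs after the user-supplied arguments; a small illustrative helper:

using Tensorflow;

static class TensorsUsage   // illustrative only, not in the commit
{
    public static Tensors Concat(Tensor[] args, Tensor[] capturedInputs)
    {
        var all = new Tensors();
        all.AddRange(args);            // user arguments first
        all.AddRange(capturedInputs);  // then the graph's captured inputs
        return all;
    }
}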


src/TensorFlowNET.Core/tensorflow.cs (+1, -1)

@@ -48,7 +48,7 @@ namespace Tensorflow
public tensorflow()
{
Logger = new LoggerConfiguration()
.MinimumLevel.Error()
.MinimumLevel.Warning()
.WriteTo.Console()
.CreateLogger();



src/TensorFlowNET.Keras/BackendImpl.cs (+2, -1)

@@ -16,6 +16,7 @@

using NumSharp;
using System;
using System.Linq;
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Graphs;
@@ -197,7 +198,7 @@ namespace Tensorflow.Keras
}
if (outputs[0].op.type == "Placeholder"
|| outputs[0].op.type == "StridedSlice")
return exec_graph.external_captures[0].numpy();
return exec_graph.external_captures.Last().numpy();

// Consolidate updates
exec_graph.as_default();


src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs (+12, -0)

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine
{
public interface ITensorFlowOpLayer
{
Layer GetOpLayer(TensorFlowOpLayerArgs args);
}
}
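This interface is resolved by the reflection-based GetLayer<T> added to base_layer_utils.cs later in this commit, which scans the TensorFlow.Keras.Layers assembly for a Layer implementing it and instantiates it with the TensorFlowOpLayerArgs. Below is a hedged sketch of the shape such an implementation is expected to have, inferred from that lookup and from the Engine-level TensorFlowOpLayer removed further down; all member bodies are placeholders.

using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

public class TensorFlowOpLayer : Layer, ITensorFlowOpLayer   // sketch only
{
    public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
        : base(new LayerArgs
        {
            Name = "tf_op_layer_" + args.Name,
            Trainable = args.Trainable,
            DType = args.DType,
            Autocast = false
        })
    {
    }

    public Layer GetOpLayer(TensorFlowOpLayerArgs args)
        => new TensorFlowOpLayer(args);

    protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
        => inputs;   // placeholder; the real layer re-creates the captured op from its NodeDef
}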

src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+1, -9)

@@ -1,5 +1,4 @@
using NumSharp;
using ShellProgressBar;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -88,15 +87,8 @@ namespace Tensorflow.Keras.Engine
{
stop_training = false;
_train_counter.assign(0);
var options = new ProgressBarOptions
{
ProgressCharacter = '.',
ProgressBarOnBottom = true
};

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options);
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
@@ -105,7 +97,7 @@ namespace Tensorflow.Keras.Engine
// callbacks.on_train_batch_begin(step)
var results = step_function(iterator);
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
}
}
}


src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs (+0, -66)

@@ -1,66 +0,0 @@
using NumSharp;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Graphs;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public class TensorFlowOpLayer : Layer
{
TensorFlowOpLayerArgs args;
Dictionary<int, NDArray> constants => args.Constants;
NodeDef node_def => args.NodeDef;
static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_";
public string OpType => node_def.Op;

public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
: base(new LayerArgs
{
Name = TF_OP_LAYER_NAME_PREFIX + args.Name,
Trainable = args.Trainable,
DType = args.DType,
Autocast = false
})
{
this.args = args;
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (tf.Context.executing_eagerly())
return _defun_call(inputs);
return MakOp(inputs);
}

[AutoGraph]
Tensors _defun_call(Tensors inputs)
=> MakOp(inputs);

Tensors MakOp(Tensors inputs)
{
foreach (var (index, constant) in enumerate(constants))
{
var value = constant_op.constant(constant, name: node_def.Input[index]);
var new_inputs = inputs.ToList();
new_inputs.Insert(index, value);
inputs = new Tensors(new_inputs.ToArray());
}

var graph = inputs.graph;
var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]);
var op = graph._create_op_from_tf_operation(c_op);
op._control_flow_post_processing();

// Record the gradient because custom-made ops don't go through the
// code-gen'd eager call path
var op_type = op.node_def.Op;

tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs);

return op.outputs;
}
}
}

src/TensorFlowNET.Keras/KerasInterface.cs (+4, -1)

@@ -1,4 +1,7 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine;


src/TensorFlowNET.Keras/Tensorflow.Keras.csproj (+1, -2)

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
@@ -47,7 +47,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="SharpZipLib" Version="1.3.1" />
<PackageReference Include="ShellProgressBar" Version="5.0.0" />
</ItemGroup>

<ItemGroup>


src/TensorFlowNET.Keras/Utils/base_layer_utils.cs (+16, -1)

@@ -18,6 +18,7 @@ using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -151,7 +152,7 @@ namespace Tensorflow.Keras.Utils

// recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
var op_layer = GetLayer<ITensorFlowOpLayer>(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Constants = constants,
@@ -164,6 +165,20 @@ namespace Tensorflow.Keras.Utils
}
}

static Layer GetLayer<T>(LayerArgs args)
{
Layer layer = default;
var assemble = Assembly.Load("TensorFlow.Keras.Layers");
foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null))
{
layer = (Layer)Activator.CreateInstance(type, new object[] { args });
}

if (layer == null)
throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}");
return layer;
}

// recusive
static bool uses_keras_history(Tensor op_input)
{

