@@ -14,6 +14,7 @@ namespace Tensorflow.Functions | |||||
{ | { | ||||
IntPtr _handle; | IntPtr _handle; | ||||
FuncGraph func_graph; | FuncGraph func_graph; | ||||
public Tensor[] Inputs => func_graph.Inputs; | |||||
public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
public string Name | public string Name | ||||
@@ -127,30 +128,53 @@ namespace Tensorflow.Functions | |||||
func_graph.Exit(); | func_graph.Exit(); | ||||
} | } | ||||
public Tensors Invoke(Tensors inputs) | |||||
public Tensors FilteredCall(Tensors inputs) | |||||
{ | { | ||||
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | |||||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||||
Tensors flat_outputs = null; | |||||
if (tf.Context.executing_eagerly()) | |||||
flat_outputs = forward_function.Call(args_with_tangents); | |||||
forward_backward.Record(flat_outputs); | |||||
return flat_outputs; | |||||
return CallFlat(inputs, CapturedInputs); | |||||
} | } | ||||
/// <summary> | |||||
/// Executes the wrapped function. | |||||
/// </summary> | |||||
/// <param name="args"></param> | |||||
/// <param name="captured_inputs"></param> | |||||
/// <returns></returns> | |||||
public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) | public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) | ||||
{ | { | ||||
var new_args = new List<Tensor>(); | |||||
new_args.AddRange(args); | |||||
new_args.AddRange(captured_inputs); | |||||
args = new_args.ToArray(); | |||||
var executing_eagerly = tf.Context.executing_eagerly(); | |||||
var default_graph = ops.get_default_graph(); | |||||
var tensor_inputs = new Tensors(); | |||||
foreach (var (i, arg) in enumerate(args)) | |||||
{ | |||||
tensor_inputs.Add(arg); | |||||
// If we're graph building, shape inference is on. | |||||
if (!executing_eagerly) | |||||
{ | |||||
} | |||||
} | |||||
tensor_inputs.AddRange(captured_inputs); | |||||
args = tensor_inputs.ToArray(); | |||||
var attrs = new object[] | |||||
var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||||
// No tape is watching; skip to running the function. | |||||
if (possible_gradient_type == 0 && executing_eagerly) | |||||
{ | { | ||||
"executor_type", "", | |||||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||||
}; | |||||
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||||
var attrs = new object[] | |||||
{ | |||||
"executor_type", "", | |||||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||||
}; | |||||
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||||
} | |||||
var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||||
Tensors flat_outputs = null; | |||||
if (executing_eagerly) | |||||
flat_outputs = forward_function.Call(args_with_tangents); | |||||
forward_backward.Record(flat_outputs); | |||||
return flat_outputs; | |||||
} | } | ||||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
@@ -31,11 +31,17 @@ namespace Tensorflow.Functions | |||||
public Tensors Call(Tensors args) | public Tensors Call(Tensors args) | ||||
{ | { | ||||
var attrs = new object[] | |||||
{ | |||||
"executor_type", "", | |||||
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||||
}; | |||||
var results = tf.Runner.TFE_Execute(tf.Context, | var results = tf.Runner.TFE_Execute(tf.Context, | ||||
tf.Context.DeviceName, | tf.Context.DeviceName, | ||||
_func_graph.FuncName, | _func_graph.FuncName, | ||||
args, | args, | ||||
null, | |||||
attrs, | |||||
_num_outputs); | _num_outputs); | ||||
return results; | return results; | ||||
@@ -49,24 +49,61 @@ namespace Tensorflow.Functions | |||||
getBackwardFunction: () => backward_function); | getBackwardFunction: () => backward_function); | ||||
} | } | ||||
/// <summary> | |||||
/// Create a backward function given `outputs` from the forward function. | |||||
/// </summary> | |||||
/// <param name="forward_graph"></param> | |||||
/// <param name="backward"></param> | |||||
/// <param name="outputs"></param> | |||||
/// <returns></returns> | |||||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | ||||
{ | { | ||||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
var capture_mapping = new Dictionary<long, Tensor>(); | |||||
foreach(var (i, output) in enumerate(outputs)) | |||||
capture_mapping[forward_graph.Outputs[i].Id] = output; | |||||
var remapped_captures = new Tensors(); | |||||
foreach(var capture in backward.CapturedInputs) | |||||
{ | |||||
if (capture_mapping.ContainsKey(capture.Id)) | |||||
remapped_captures.Add(capture_mapping[capture.Id]); | |||||
} | |||||
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||||
var recorded_outputs = new Tensors(); | |||||
var relevant_outputs = outputs; | |||||
var trainable_recorded_outputs = 0; | |||||
var skip_positions = new List<int>(); | |||||
foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||||
{ | |||||
if (trainable_recorded_outputs < backward_function_inputs) | |||||
recorded_outputs.Add(output); | |||||
if (gradients_util.IsTrainable(output)) | |||||
trainable_recorded_outputs += 1; | |||||
else | |||||
skip_positions.Add(output_index); | |||||
} | |||||
BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => | |||||
{ | { | ||||
var processed_args = new List<Tensor>(); | |||||
var processed_args = new Tensors(); | |||||
var input_index = 0; | var input_index = 0; | ||||
foreach (var (output_index, arg) in enumerate(output_grads)) | |||||
foreach (var (output_index, arg) in enumerate(args)) | |||||
{ | { | ||||
if (arg is null) | |||||
if (skip_positions.Contains(output_index)) | |||||
continue; | |||||
if (arg == null) | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
processed_args.add(arg); | |||||
processed_args.Add(arg); | |||||
input_index += 1; | input_index += 1; | ||||
if (input_index >= backward_function_inputs) | |||||
break; | |||||
} | } | ||||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | ||||
return backward.CallFlat(processed_args.ToArray(), outputs); | |||||
return backward.CallFlat(processed_args, remapped_captures); | |||||
}; | }; | ||||
return (_backward_function_wrapper, outputs); | |||||
return (_backward_function_wrapper, recorded_outputs); | |||||
} | } | ||||
protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | ||||
@@ -103,7 +140,7 @@ namespace Tensorflow.Functions | |||||
} | } | ||||
backwards_graph.Exit(); | backwards_graph.Exit(); | ||||
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||||
var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||||
var backward_function_attr = new Dictionary<string, string>(); | var backward_function_attr = new Dictionary<string, string>(); | ||||
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | ||||
gradients_wrt_outputs.append(backwards_graph.internal_captures); | gradients_wrt_outputs.append(backwards_graph.internal_captures); | ||||
@@ -228,13 +228,14 @@ namespace Tensorflow.Gradients | |||||
var grad = grads[0]; | var grad = grads[0]; | ||||
var x = op.inputs[0]; | var x = op.inputs[0]; | ||||
var a = op.inputs[1]; | var a = op.inputs[1]; | ||||
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); | |||||
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); | |||||
var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) }); | |||||
var begin = constant_op.constant(new[] { 0, 0 }); | |||||
var pad_before = array_ops.slice(a, begin, size); | |||||
// Make it a 1-D tensor. | // Make it a 1-D tensor. | ||||
var begin = array_ops.reshape(pad_before, new[] { -1 }); | |||||
var sizes = array_ops.shape(x); | |||||
var x_grad = array_ops.slice(grad, begin, sizes); | |||||
begin = array_ops.reshape(pad_before, new[] { -1 }); | |||||
size = array_ops.shape(x); | |||||
var x_grad = array_ops.slice(grad, begin, size); | |||||
if (len(op.inputs) == 3) | if (len(op.inputs) == 3) | ||||
return new Tensor[] { x_grad, null, null }; | return new Tensor[] { x_grad, null, null }; | ||||
@@ -30,7 +30,7 @@ namespace Tensorflow.Gradients | |||||
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | ||||
Tensor image_shape = null; | Tensor image_shape = null; | ||||
if (shape.is_fully_defined()) | if (shape.is_fully_defined()) | ||||
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); | |||||
image_shape = constant_op.constant(image.shape[1..3]); | |||||
else | else | ||||
image_shape = array_ops.shape(image)["1:3"]; | image_shape = array_ops.shape(image)["1:3"]; | ||||
@@ -8,6 +8,9 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow.Graphs | namespace Tensorflow.Graphs | ||||
{ | { | ||||
/// <summary> | |||||
/// func_graph.py func_graph_from_py_func | |||||
/// </summary> | |||||
[AllowChangingInputArguments] | [AllowChangingInputArguments] | ||||
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | ||||
{ | { | ||||
@@ -18,15 +21,16 @@ namespace Tensorflow.Graphs | |||||
public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
{ | { | ||||
func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; | |||||
// TODO: func_name can be cache in FullName + Args | |||||
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}"; | |||||
if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
{ | { | ||||
function = functions[func_name]; | function = functions[func_name]; | ||||
if (args.Arguments[0] is Tensors tensor_inputs) | if (args.Arguments[0] is Tensors tensor_inputs) | ||||
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||||
args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs)); | |||||
else | else | ||||
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||||
args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray())); | |||||
args.FlowBehavior = FlowBehavior.Return; | args.FlowBehavior = FlowBehavior.Return; | ||||
return; | return; | ||||
} | } | ||||
@@ -62,14 +66,27 @@ namespace Tensorflow.Graphs | |||||
{ | { | ||||
if (args.ReturnValue is Tensors outputs) | if (args.ReturnValue is Tensors outputs) | ||||
{ | { | ||||
if (args.Arguments[0] is Tensors inputs) | |||||
function.ToGraph(inputs, outputs); | |||||
Tensors inputs = null; | |||||
outputs = mark_as_return(outputs); | |||||
if (args.Arguments[0] is Tensors inputs1) | |||||
inputs = inputs1; | |||||
else | else | ||||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||||
inputs = args.Arguments.Select(x => x as Tensor).ToArray(); | |||||
inputs = inputs.Where(x => x.op.OpType == "Placeholder" | |||||
&& x.op.name.StartsWith("inputs")).ToArray(); | |||||
function.ToGraph(inputs, outputs); | |||||
} | } | ||||
else | |||||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||||
else if (args.ReturnValue is Tensor output) | |||||
{ | |||||
var inputs = args.Arguments.Select(x => x as Tensor) | |||||
.Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs")) | |||||
.ToArray(); | |||||
var outputs2 = array_ops.identity(output); | |||||
function.ToGraph(inputs, outputs2); | |||||
} | |||||
function.Exit(); | function.Exit(); | ||||
// cache function. | // cache function. | ||||
@@ -77,7 +94,7 @@ namespace Tensorflow.Graphs | |||||
functions[func_name] = function; | functions[func_name] = function; | ||||
// run function | // run function | ||||
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||||
args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs)); | |||||
} | } | ||||
object ConvertReturnValue(Tensors tensors) | object ConvertReturnValue(Tensors tensors) | ||||
@@ -87,5 +104,20 @@ namespace Tensorflow.Graphs | |||||
else | else | ||||
return tensors; | return tensors; | ||||
} | } | ||||
/// <summary> | |||||
/// Acts like identity but marks the `Tensor` as a return value. | |||||
/// </summary> | |||||
/// <param name="tensors"></param> | |||||
/// <returns></returns> | |||||
public Tensors mark_as_return(Tensors tensors) | |||||
{ | |||||
if (tensors == null) | |||||
return null; | |||||
var result = new Tensors(); | |||||
foreach (var tensor in tensors) | |||||
result.Add(array_ops.identity(tensor)); | |||||
return result; | |||||
} | |||||
} | } | ||||
} | } |
@@ -925,7 +925,28 @@ namespace Tensorflow | |||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | ||||
=> gen_array_ops.slice(input, begin, size, name: name); | => gen_array_ops.slice(input, begin, size, name: name); | ||||
public static Tensor stack(object values, int axis = 0, string name = "stack") | |||||
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) | |||||
=> tf.Context.RunInAutoMode2( | |||||
() => tf.OpDefLib._apply_op_helper("Slice", name, new | |||||
{ | |||||
input, begin, size | |||||
}).output, | |||||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"Slice", name, | |||||
null, | |||||
input, begin, size).FirstOrDefault(), | |||||
(op) => | |||||
{ | |||||
var attrs = new object[] | |||||
{ | |||||
"T", op.get_attr<TF_DataType>("T"), | |||||
"Index", op.get_attr<int>("Index") | |||||
}; | |||||
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); | |||||
}, | |||||
new Tensors(input, begin, size)); | |||||
public static Tensor stack(object values, int axis = 0, string name = "stack") | |||||
{ | { | ||||
if (axis == 0) | if (axis == 0) | ||||
// If the input is a constant list, it can be converted to a constant op | // If the input is a constant list, it can be converted to a constant op | ||||
@@ -238,18 +238,32 @@ namespace Tensorflow | |||||
"half_pixel_centers", half_pixel_centers).FirstOrDefault(), | "half_pixel_centers", half_pixel_centers).FirstOrDefault(), | ||||
images); | images); | ||||
public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false, | |||||
public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, | |||||
bool half_pixel_centers = false, string name = null) | bool half_pixel_centers = false, string name = null) | ||||
{ | |||||
var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new | |||||
{ | |||||
grads, | |||||
size, | |||||
align_corners, | |||||
half_pixel_centers | |||||
}); | |||||
return op.output; | |||||
} | |||||
=> tf.Context.RunInAutoMode2( | |||||
() => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new | |||||
{ | |||||
grads, | |||||
size, | |||||
align_corners, | |||||
half_pixel_centers | |||||
}).output, | |||||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"ResizeNearestNeighborGrad", name, | |||||
null, | |||||
grads, size, | |||||
"align_corners", align_corners, | |||||
"half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||||
(op) => | |||||
{ | |||||
var attrs = new object[] | |||||
{ | |||||
"T", op.get_attr<TF_DataType>("T"), | |||||
"align_corners", op.get_attr<bool>("align_corners"), | |||||
"half_pixel_centers", op.get_attr<bool>("half_pixel_centers") | |||||
}; | |||||
tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs); | |||||
}, | |||||
new Tensors(grads, size)); | |||||
} | } | ||||
} | } |
@@ -126,6 +126,16 @@ namespace Tensorflow | |||||
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | ||||
string name = null) | string name = null) | ||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"RandomShuffle", name, | |||||
null, | |||||
value, seed, seed2); | |||||
return results[0]; | |||||
} | |||||
var _op = tf.OpDefLib._apply_op_helper("RandomShuffle", | var _op = tf.OpDefLib._apply_op_helper("RandomShuffle", | ||||
name: name, | name: name, | ||||
args: new { value, seed, seed2 }); | args: new { value, seed, seed2 }); | ||||
@@ -83,6 +83,7 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including: | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | ||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" /> | |||||
<PackageReference Include="NumSharp.Lite" Version="0.1.10" /> | <PackageReference Include="NumSharp.Lite" Version="0.1.10" /> | ||||
<PackageReference Include="Protobuf.Text" Version="0.4.0" /> | <PackageReference Include="Protobuf.Text" Version="0.4.0" /> | ||||
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" /> | <PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" /> | ||||
@@ -57,6 +57,9 @@ namespace Tensorflow | |||||
public void Add(Tensor tensor) | public void Add(Tensor tensor) | ||||
=> items.Add(tensor); | => items.Add(tensor); | ||||
public void AddRange(Tensor[] tensors) | |||||
=> items.AddRange(tensors); | |||||
IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -48,7 +48,7 @@ namespace Tensorflow | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
Logger = new LoggerConfiguration() | Logger = new LoggerConfiguration() | ||||
.MinimumLevel.Error() | |||||
.MinimumLevel.Warning() | |||||
.WriteTo.Console() | .WriteTo.Console() | ||||
.CreateLogger(); | .CreateLogger(); | ||||
@@ -16,6 +16,7 @@ | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Linq; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
@@ -197,7 +198,7 @@ namespace Tensorflow.Keras | |||||
} | } | ||||
if (outputs[0].op.type == "Placeholder" | if (outputs[0].op.type == "Placeholder" | ||||
|| outputs[0].op.type == "StridedSlice") | || outputs[0].op.type == "StridedSlice") | ||||
return exec_graph.external_captures[0].numpy(); | |||||
return exec_graph.external_captures.Last().numpy(); | |||||
// Consolidate updates | // Consolidate updates | ||||
exec_graph.as_default(); | exec_graph.as_default(); | ||||
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Engine | |||||
{ | |||||
public interface ITensorFlowOpLayer | |||||
{ | |||||
Layer GetOpLayer(TensorFlowOpLayerArgs args); | |||||
} | |||||
} |
@@ -1,5 +1,4 @@ | |||||
using NumSharp; | using NumSharp; | ||||
using ShellProgressBar; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
@@ -88,15 +87,8 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
stop_training = false; | stop_training = false; | ||||
_train_counter.assign(0); | _train_counter.assign(0); | ||||
var options = new ProgressBarOptions | |||||
{ | |||||
ProgressCharacter = '.', | |||||
ProgressBarOnBottom = true | |||||
}; | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
{ | { | ||||
using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options); | |||||
// reset_metrics(); | // reset_metrics(); | ||||
// callbacks.on_epoch_begin(epoch) | // callbacks.on_epoch_begin(epoch) | ||||
// data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
@@ -105,7 +97,7 @@ namespace Tensorflow.Keras.Engine | |||||
// callbacks.on_train_batch_begin(step) | // callbacks.on_train_batch_begin(step) | ||||
var results = step_function(iterator); | var results = step_function(iterator); | ||||
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | ||||
pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); | |||||
Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -1,66 +0,0 @@ | |||||
using NumSharp; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Graphs; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Engine | |||||
{ | |||||
public class TensorFlowOpLayer : Layer | |||||
{ | |||||
TensorFlowOpLayerArgs args; | |||||
Dictionary<int, NDArray> constants => args.Constants; | |||||
NodeDef node_def => args.NodeDef; | |||||
static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; | |||||
public string OpType => node_def.Op; | |||||
public TensorFlowOpLayer(TensorFlowOpLayerArgs args) | |||||
: base(new LayerArgs | |||||
{ | |||||
Name = TF_OP_LAYER_NAME_PREFIX + args.Name, | |||||
Trainable = args.Trainable, | |||||
DType = args.DType, | |||||
Autocast = false | |||||
}) | |||||
{ | |||||
this.args = args; | |||||
built = true; | |||||
} | |||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
return _defun_call(inputs); | |||||
return MakOp(inputs); | |||||
} | |||||
[AutoGraph] | |||||
Tensors _defun_call(Tensors inputs) | |||||
=> MakOp(inputs); | |||||
Tensors MakOp(Tensors inputs) | |||||
{ | |||||
foreach (var (index, constant) in enumerate(constants)) | |||||
{ | |||||
var value = constant_op.constant(constant, name: node_def.Input[index]); | |||||
var new_inputs = inputs.ToList(); | |||||
new_inputs.Insert(index, value); | |||||
inputs = new Tensors(new_inputs.ToArray()); | |||||
} | |||||
var graph = inputs.graph; | |||||
var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); | |||||
var op = graph._create_op_from_tf_operation(c_op); | |||||
op._control_flow_post_processing(); | |||||
// Record the gradient because custom-made ops don't go through the | |||||
// code-gen'd eager call path | |||||
var op_type = op.node_def.Op; | |||||
tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | |||||
return op.outputs; | |||||
} | |||||
} | |||||
} |
@@ -1,4 +1,7 @@ | |||||
using System.Collections.Generic; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Reflection; | |||||
using System.Linq; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Datasets; | using Tensorflow.Keras.Datasets; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
@@ -1,4 +1,4 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk"> | |||||
<Project Sdk="Microsoft.NET.Sdk"> | |||||
<PropertyGroup> | <PropertyGroup> | ||||
<TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
@@ -47,7 +47,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> | ||||
<PackageReference Include="SharpZipLib" Version="1.3.1" /> | <PackageReference Include="SharpZipLib" Version="1.3.1" /> | ||||
<PackageReference Include="ShellProgressBar" Version="5.0.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -18,6 +18,7 @@ using NumSharp; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Reflection; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -151,7 +152,7 @@ namespace Tensorflow.Keras.Utils | |||||
// recursively | // recursively | ||||
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | ||||
Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
var op_layer = GetLayer<ITensorFlowOpLayer>(new TensorFlowOpLayerArgs | |||||
{ | { | ||||
NodeDef = op.node_def, | NodeDef = op.node_def, | ||||
Constants = constants, | Constants = constants, | ||||
@@ -164,6 +165,20 @@ namespace Tensorflow.Keras.Utils | |||||
} | } | ||||
} | } | ||||
static Layer GetLayer<T>(LayerArgs args) | |||||
{ | |||||
Layer layer = default; | |||||
var assemble = Assembly.Load("TensorFlow.Keras.Layers"); | |||||
foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null)) | |||||
{ | |||||
layer = (Layer)Activator.CreateInstance(type, new object[] { args }); | |||||
} | |||||
if (layer == null) | |||||
throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}"); | |||||
return layer; | |||||
} | |||||
// recusive | // recusive | ||||
static bool uses_keras_history(Tensor op_input) | static bool uses_keras_history(Tensor op_input) | ||||
{ | { | ||||