@@ -98,35 +98,23 @@ namespace Tensorflow | |||
default: | |||
return obj?.ToString() ?? "null"; | |||
} | |||
object[] toObjectArray(Array arr) | |||
{ | |||
var len = arr.LongLength; | |||
var ret = new object[len]; | |||
for (long i = 0; i < len; i++) | |||
{ | |||
ret[i] = arr.GetValue(i); | |||
} | |||
return ret; | |||
} | |||
} | |||
private static TextWriter writer = null; | |||
private static TextWriter _writer = Console.Out; | |||
public static TextWriter tf_output_redirect { | |||
set | |||
{ | |||
var originWriter = writer ?? Console.Out; | |||
originWriter.Flush(); | |||
if (originWriter is StringWriter) | |||
(originWriter as StringWriter).GetStringBuilder().Clear(); | |||
writer = value; | |||
} | |||
get | |||
{ | |||
return writer ?? Console.Out; | |||
if(_writer != null) | |||
{ | |||
_writer.Flush(); | |||
if (_writer is StringWriter sw) | |||
sw.GetStringBuilder().Clear(); | |||
} | |||
_writer = value; | |||
} | |||
get => _writer ?? Console.Out; | |||
} | |||
public static void print(object obj) | |||
@@ -48,7 +48,7 @@ namespace Tensorflow | |||
} | |||
// free unmanaged memory | |||
// if (_handle != IntPtr.Zero) | |||
if (_handle != IntPtr.Zero) | |||
{ | |||
// Call the appropriate methods to clean up | |||
// unmanaged resources here. | |||
@@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine | |||
/// <returns></returns> | |||
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) | |||
{ | |||
callContext = callContext?.Value != null ? callContext : new ThreadLocal<CallContext>() | |||
{ | |||
Value = new CallContext() | |||
}; | |||
if (callContext.Value == null) | |||
callContext.Value = new CallContext(); | |||
if (_in_functional_construction_mode(inputs)) | |||
return FunctionalConstructionCall(inputs); | |||
Tensors outputs = null; | |||
var eager = tf.executing_eagerly(); | |||
using var ctxManager = CallContext.enter(build_graph: false); | |||
string nameScope = ""; | |||
if (eager) | |||
nameScope = Name; | |||
else | |||
nameScope = _name_scope(); | |||
string nameScope = eager ? name : _name_scope(); | |||
var scope = ops.name_scope(nameScope); | |||
scope.__enter__(); | |||
if (!built) | |||
MaybeBuild(inputs); | |||
outputs = Call(inputs, state: state, training: training); | |||
var outputs = Call(inputs, state: state, training: training); | |||
// memory leak | |||
// _set_connectivity_metadata_(inputs, outputs); | |||
@@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine | |||
List<INode> outboundNodes; | |||
public List<INode> OutboundNodes => outboundNodes; | |||
ThreadLocal<CallContext> callContext; | |||
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
public CallContext CallContext => callContext.Value; | |||
public Tensor[] input => inboundNodes[0].input_tensors; | |||
public Dictionary<int, List<INode>> NodesByDepth { get; set; } | |||
public Shape output_shape => inboundNodes[0].Outputs.shape; | |||
protected List<ILayer> _self_tracked_trackables; | |||
public Layer(LayerArgs args) | |||
{ | |||
this.args = args; | |||
@@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine | |||
non_trainable_weights = new List<IVariableV1>(); | |||
computePreviousMask = false; | |||
updates = new List<Operation>(); | |||
_self_tracked_trackables = new List<ILayer>(); | |||
inboundNodes = new List<INode>(); | |||
outboundNodes = new List<INode>(); | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
@@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine | |||
bool _auto_track_sub_layers; | |||
Shape _inferred_input_shape; | |||
bool _has_explicit_input_shape; | |||
bool _graph_initialized; | |||
public Shape output_shape => outputs[0].shape; | |||
List<INode> _created_nodes; | |||
public Sequential(SequentialArgs args) | |||
: base(args.Inputs, args.Outputs, name: args.Name) | |||
@@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine | |||
_auto_track_sub_layers = false; | |||
_has_explicit_input_shape = false; | |||
_is_graph_network = false; | |||
_created_nodes = new List<INode>(); | |||
// Add to the model any layers passed to the constructor. | |||
if (args.Layers != null) | |||
{ | |||
foreach (var layer in args.Layers) | |||
add(layer as Layer); | |||
add(layer); | |||
} | |||
} | |||
@@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
else | |||
{ | |||
_self_tracked_trackables.add(layer); | |||
_handle_deferred_layer_dependencies(layer); | |||
} | |||
} | |||
void _handle_deferred_layer_dependencies(params ILayer[] layers) | |||
{ | |||
_layers.AddRange(layers); | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
{ | |||
if (!_has_explicit_input_shape) | |||
{ | |||
_build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype); | |||
} | |||
if(_graph_initialized) | |||
{ | |||
if (!built) | |||
_init_graph_network(this.inputs, outputs); | |||
return base.Call(inputs, state, training); | |||
} | |||
return base.Call(inputs, state, training); | |||
} | |||
void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) | |||
{ | |||
ops.init_scope(); | |||
var inputs = keras.Input(batch_input_shape: input_shape, | |||
dtype: input_dtype, | |||
name: $"{_layers[0].Name}_input"); | |||
Tensors layer_input = inputs; | |||
Tensors layer_output = null; | |||
Tensors outputs = null; | |||
foreach (var layer in _layers) | |||
{ | |||
clear_previously_created_nodes(layer, _created_nodes); | |||
layer_output = layer.Apply(layer_input); | |||
// Keep track of nodes just created above | |||
track_nodes_created_by_last_call(layer, _created_nodes); | |||
layer_input = layer_output; | |||
outputs = layer_output; | |||
} | |||
_init_graph_network(inputs, outputs); | |||
_graph_initialized = true; | |||
_inferred_input_shape = input_shape; | |||
} | |||
void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes) | |||
{ | |||
} | |||
void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes) | |||
{ | |||
var node = layer.InboundNodes.Last(); | |||
created_nodes.Add(node); | |||
foreach(var prev_layer in node.iterate_inbound()) | |||
{ | |||
created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); | |||
} | |||
} | |||
} | |||
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers | |||
var rank = inputs.rank; | |||
if (rank > 2) | |||
{ | |||
throw new NotImplementedException("call rank > 2"); | |||
outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } }); | |||
} | |||
else | |||
{ | |||