@@ -13,10 +13,8 @@ namespace Tensorflow | |||||
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||||
public async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||||
{ | { | ||||
var loader = new MnistModelLoader(); | |||||
var setting = new ModelLoadSetting | var setting = new ModelLoadSetting | ||||
{ | { | ||||
TrainDir = trainDir, | TrainDir = trainDir, | ||||
@@ -33,7 +31,7 @@ namespace Tensorflow | |||||
if (testSize.HasValue) | if (testSize.HasValue) | ||||
setting.TestSize = testSize.Value; | setting.TestSize = testSize.Value; | ||||
return await loader.LoadAsync(setting); | |||||
return await LoadAsync(setting); | |||||
} | } | ||||
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | ||||
@@ -42,7 +42,7 @@ namespace Tensorflow.Eager | |||||
int num_outputs) | int num_outputs) | ||||
{ | { | ||||
var status = tf.Status; | var status = tf.Status; | ||||
using var op = GetOp(ctx, op_name, status); | |||||
var op = GetOp(ctx, op_name, status); | |||||
c_api.TFE_OpSetDevice(op, device_name, status.Handle); | c_api.TFE_OpSetDevice(op, device_name, status.Handle); | ||||
if (status.ok()) | if (status.ok()) | ||||
{ | { | ||||
@@ -15,7 +15,9 @@ namespace Tensorflow.Eager | |||||
/// </summary> | /// </summary> | ||||
public partial class EagerRunner | public partial class EagerRunner | ||||
{ | { | ||||
UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>(); | |||||
UnorderedMap<string, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<string, SafeOpHandle>(); | |||||
public void ClearEagerOperationMap() | |||||
=> thread_local_eager_operation_map.Clear(); | |||||
public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info) | public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info) | ||||
{ | { | ||||
@@ -31,7 +33,7 @@ namespace Tensorflow.Eager | |||||
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks; | op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks; | ||||
var status = tf.Status; | var status = tf.Status; | ||||
using var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status); | |||||
var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status); | |||||
var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name); | var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name); | ||||
@@ -56,8 +58,8 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
} | } | ||||
c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle); | |||||
status.Check(true); | |||||
// c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle); | |||||
// status.Check(true); | |||||
// Add inferred attrs and inputs. | // Add inferred attrs and inputs. | ||||
for (int i = 0; i < op_def.InputArg.Count; i++) | for (int i = 0; i < op_def.InputArg.Count; i++) | ||||
@@ -145,7 +147,6 @@ namespace Tensorflow.Eager | |||||
var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | ||||
if (op_exec_info.run_callbacks) | if (op_exec_info.run_callbacks) | ||||
{ | { | ||||
RunCallbacks(op_exec_info, | RunCallbacks(op_exec_info, | ||||
@@ -158,19 +159,19 @@ namespace Tensorflow.Eager | |||||
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | ||||
{ | { | ||||
/*if (thread_local_eager_operation_map.find(ctx, out var op)) | |||||
if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | |||||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | ||||
else | else | ||||
{ | { | ||||
op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | ||||
thread_local_eager_operation_map[ctx] = op; | |||||
thread_local_eager_operation_map[op_or_function_name] = op; | |||||
} | } | ||||
status.Check(true); | |||||
return op;*/ | |||||
var op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | |||||
status.Check(true); | status.Check(true); | ||||
return op; | return op; | ||||
/*var op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | |||||
status.Check(true); | |||||
return op;*/ | |||||
} | } | ||||
bool HasAccumulator() | bool HasAccumulator() | ||||
@@ -268,16 +269,7 @@ namespace Tensorflow.Eager | |||||
if (attr_value == null) | if (attr_value == null) | ||||
{ | { | ||||
if (is_list != 0) | |||||
#pragma warning disable CS0642 // Possible mistaken empty statement | |||||
; | |||||
#pragma warning restore CS0642 // Possible mistaken empty statement | |||||
//SetOpAttrListDefault | |||||
else | |||||
#pragma warning disable CS0642 // Possible mistaken empty statement | |||||
; | |||||
#pragma warning restore CS0642 // Possible mistaken empty statement | |||||
//SetOpAttrScalarDefault | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -39,5 +39,7 @@ namespace Tensorflow.Eager | |||||
bool MustRecordGradient(); | bool MustRecordGradient(); | ||||
int TapeSetPossibleGradientTypes(params Tensor[] args); | int TapeSetPossibleGradientTypes(params Tensor[] args); | ||||
void ClearEagerOperationMap(); | |||||
} | } | ||||
} | } |
@@ -44,14 +44,14 @@ namespace Tensorflow.Framework | |||||
return true; | return true; | ||||
} | } | ||||
if (other.IsSparseTensor) | |||||
if (other is SparseTensor) | |||||
{ | { | ||||
return self.dtype.is_compatible_with(other.dtype); | return self.dtype.is_compatible_with(other.dtype); | ||||
} | } | ||||
return self.dtype.is_compatible_with(other.dtype) && | return self.dtype.is_compatible_with(other.dtype) && | ||||
_shape_is_compatible_0dim(self.shape, other.shape) && | _shape_is_compatible_0dim(self.shape, other.shape) && | ||||
!self.IsSparseTensor; | |||||
!(self is SparseTensor); | |||||
} | } | ||||
public static Dimension dimension_at_index(Shape shape, int index) | public static Dimension dimension_at_index(Shape shape, int index) | ||||
@@ -30,10 +30,6 @@ namespace Tensorflow | |||||
public class BaseSession : DisposableObject | public class BaseSession : DisposableObject | ||||
{ | { | ||||
protected Graph _graph; | protected Graph _graph; | ||||
protected bool _opened; | |||||
protected bool _closed; | |||||
protected int _current_version; | |||||
protected byte[] _target; | |||||
public Graph graph => _graph; | public Graph graph => _graph; | ||||
public BaseSession(IntPtr handle, Graph g) | public BaseSession(IntPtr handle, Graph g) | ||||
@@ -46,18 +42,15 @@ namespace Tensorflow | |||||
{ | { | ||||
_graph = g ?? ops.get_default_graph(); | _graph = g ?? ops.get_default_graph(); | ||||
if (!_graph.building_function) | if (!_graph.building_function) | ||||
_graph.as_default(); | |||||
_target = Encoding.UTF8.GetBytes(target); | |||||
using (var opts = new SessionOptions(target, config)) | |||||
{ | { | ||||
lock (Locks.ProcessWide) | |||||
{ | |||||
status = status ?? new Status(); | |||||
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); | |||||
status.Check(true); | |||||
} | |||||
if (ops.get_default_graph() != _graph) | |||||
_graph.as_default(); | |||||
} | } | ||||
using var opts = new SessionOptions(target, config); | |||||
status = status ?? tf.Status; | |||||
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); | |||||
status.Check(true); | |||||
} | } | ||||
public virtual void run(Operation op, params FeedItem[] feed_dict) | public virtual void run(Operation op, params FeedItem[] feed_dict) | ||||
@@ -26,80 +26,6 @@ namespace Tensorflow | |||||
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] | ||||
public partial class Tensor | public partial class Tensor | ||||
{ | { | ||||
public unsafe void CopyTo(NDArray nd) | |||||
{ | |||||
//if (!nd.Shape.IsContiguous) | |||||
//throw new ArgumentException("NDArray has to be contiguous (ndarray.Shape.IsContiguous)."); | |||||
var length = (int)(nd.size * nd.dtypesize); | |||||
switch (nd.dtype) | |||||
{ | |||||
/*case NumpyDType.Boolean: | |||||
{ | |||||
CopyTo(new Span<bool>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Byte: | |||||
{ | |||||
CopyTo(new Span<byte>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Int16: | |||||
{ | |||||
CopyTo(new Span<short>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.UInt16: | |||||
{ | |||||
CopyTo(new Span<ushort>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Int32: | |||||
{ | |||||
CopyTo(new Span<int>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.UInt32: | |||||
{ | |||||
CopyTo(new Span<uint>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Int64: | |||||
{ | |||||
CopyTo(new Span<long>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.UInt64: | |||||
{ | |||||
CopyTo(new Span<ulong>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Char: | |||||
{ | |||||
CopyTo(new Span<char>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Double: | |||||
{ | |||||
CopyTo(new Span<double>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
} | |||||
case NumpyDType.Single: | |||||
{ | |||||
CopyTo(new Span<float>(nd.Address.ToPointer(), length)); | |||||
break; | |||||
}*/ | |||||
default: | |||||
throw new NotSupportedException(); | |||||
} | |||||
} | |||||
public void CopyTo<T>(Span<T> destination) where T : unmanaged | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public TensorSpec ToTensorSpec() | public TensorSpec ToTensorSpec() | ||||
=> new TensorSpec(shape, dtype, name); | => new TensorSpec(shape, dtype, name); | ||||
} | } |
@@ -32,7 +32,7 @@ namespace Tensorflow | |||||
public Tensor() | public Tensor() | ||||
{ | { | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -45,7 +45,7 @@ namespace Tensorflow | |||||
if (clone && handle != null) | if (clone && handle != null) | ||||
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); | _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -59,13 +59,13 @@ namespace Tensorflow | |||||
public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) | public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) | ||||
{ | { | ||||
_handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); | _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
public unsafe Tensor(NDArray nd) | public unsafe Tensor(NDArray nd) | ||||
{ | { | ||||
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
#region scala | #region scala | ||||
@@ -107,13 +107,13 @@ namespace Tensorflow | |||||
_value_index = value_index; | _value_index = value_index; | ||||
_override_dtype = dtype; | _override_dtype = dtype; | ||||
_id = ops.uid(); | _id = ops.uid(); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | ||||
{ | { | ||||
_handle = TF_NewTensor(shape, dtype, null); | _handle = TF_NewTensor(shape, dtype, null); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) | protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) | ||||
@@ -122,12 +122,12 @@ namespace Tensorflow | |||||
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); | _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); | ||||
else | else | ||||
_handle = TF_NewTensor(bytes, shape, dtype); | _handle = TF_NewTensor(bytes, shape, dtype); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
protected unsafe void InitTensor(Array array, Shape? shape = null) | protected unsafe void InitTensor(Array array, Shape? shape = null) | ||||
{ | { | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
_isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
shape = shape ?? array.GetShape(); | shape = shape ?? array.GetShape(); | ||||
var dtype = array.GetDataType(); | var dtype = array.GetDataType(); | ||||
@@ -17,11 +17,8 @@ | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using System; | using System; | ||||
using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
using System.Globalization; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Framework; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -97,12 +94,9 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle; | public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle; | ||||
protected bool isCreatedInGraphMode; | |||||
protected bool _isCreatedInGraphMode; | |||||
public bool IsCreatedInGraphMode => isCreatedInGraphMode; | |||||
public bool IsSparseTensor => this is SparseTensor; | |||||
public Tensor TensorShape => tf.shape(this); | |||||
public bool IsCreatedInGraphMode => _isCreatedInGraphMode; | |||||
/// <summary> | /// <summary> | ||||
/// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
@@ -157,7 +151,6 @@ namespace Tensorflow | |||||
/// Keras History: (Layer, (node_index, tensor_index)) | /// Keras History: (Layer, (node_index, tensor_index)) | ||||
/// </summary> | /// </summary> | ||||
public KerasHistory KerasHistory { get; set; } | public KerasHistory KerasHistory { get; set; } | ||||
public Tensor KerasMask { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||
@@ -383,6 +383,9 @@ namespace Tensorflow | |||||
public static void reset_uid() | public static void reset_uid() | ||||
{ | { | ||||
uid_number = -1; | uid_number = -1; | ||||
graph_uid_number = -1; | |||||
uid_number_for_function = 0; | |||||
uid_number_for_layer = 0; | |||||
} | } | ||||
public static void colocate_with(bool ignore_existing = false) | public static void colocate_with(bool ignore_existing = false) | ||||
@@ -126,9 +126,10 @@ namespace Tensorflow.Keras | |||||
PER_GRAPH_LAYER_NAME_UIDS.Clear(); | PER_GRAPH_LAYER_NAME_UIDS.Clear(); | ||||
_CURRENT_SCRATCH_GRAPH = null; | _CURRENT_SCRATCH_GRAPH = null; | ||||
_GRAPH = null; | _GRAPH = null; | ||||
ops.set_default_session(tf.Session(ops.get_default_graph())); | ops.set_default_session(tf.Session(ops.get_default_graph())); | ||||
tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
tf.Runner.ClearEagerOperationMap(); | |||||
GC.Collect(); | GC.Collect(); | ||||
GC.WaitForPendingFinalizers(); | GC.WaitForPendingFinalizers(); | ||||
@@ -12,11 +12,6 @@ namespace Tensorflow.Keras.Engine | |||||
/// </summary> | /// </summary> | ||||
public partial class Functional : Model | public partial class Functional : Model | ||||
{ | { | ||||
Shape _build_input_shape; | |||||
bool _compute_output_and_mask_jointly; | |||||
bool _expects_training_arg; | |||||
bool _expects_mask_arg; | |||||
bool _autocast; | |||||
List<ILayer> _output_layers; | List<ILayer> _output_layers; | ||||
List<ILayer> _input_layers; | List<ILayer> _input_layers; | ||||
List<KerasHistory> _input_coordinates; | List<KerasHistory> _input_coordinates; | ||||
@@ -49,12 +44,6 @@ namespace Tensorflow.Keras.Engine | |||||
this.inputs = inputs; | this.inputs = inputs; | ||||
this.outputs = outputs; | this.outputs = outputs; | ||||
built = true; | built = true; | ||||
_build_input_shape = inputs.shape; | |||||
_compute_output_and_mask_jointly = true; | |||||
_expects_training_arg = true; | |||||
_expects_mask_arg = true; | |||||
// A graph network does not autocast inputs, as its layers will cast them instead. | |||||
_autocast = false; | |||||
if (outputs.Any(x => x.KerasHistory == null)) | if (outputs.Any(x => x.KerasHistory == null)) | ||||
base_layer_utils.create_keras_history(outputs); | base_layer_utils.create_keras_history(outputs); | ||||
@@ -303,23 +292,11 @@ namespace Tensorflow.Keras.Engine | |||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
{ | { | ||||
return run_internal_graph(inputs, training.Value); | |||||
} | |||||
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) | |||||
{ | |||||
if (mask == null) | |||||
{ | |||||
Tensor[] masks = new Tensor[inputs.Count()]; | |||||
foreach (var (i, input_t) in enumerate(inputs)) | |||||
input_t.KerasMask = masks[i]; | |||||
} | |||||
var tensor_dict = new Dictionary<long, Queue<Tensor>>(); | var tensor_dict = new Dictionary<long, Queue<Tensor>>(); | ||||
// map input values | |||||
foreach (var (x, y) in zip(this.inputs, inputs)) | foreach (var (x, y) in zip(this.inputs, inputs)) | ||||
{ | { | ||||
var y1 = conform_to_reference_input(y, x); | |||||
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1)); | |||||
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y)); | |||||
} | } | ||||
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray(); | var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray(); | ||||
@@ -336,11 +313,11 @@ namespace Tensorflow.Keras.Engine | |||||
var layer_inputs = node.MapArguments(tensor_dict); | var layer_inputs = node.MapArguments(tensor_dict); | ||||
tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); | tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); | ||||
var outputs = node.Layer.Apply(layer_inputs, is_training: training); | |||||
var outputs = node.Layer.Apply(layer_inputs, is_training: training ?? false); | |||||
foreach (var output in outputs.Where(x => x != null)) | foreach (var output in outputs.Where(x => x != null)) | ||||
tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}"); | tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}"); | ||||
// Update tensor_dict for next input | // Update tensor_dict for next input | ||||
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | |||||
foreach (var (x_id, y) in zip(node.Outputs.Select(x => x.Id), outputs)) | |||||
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | ||||
} | } | ||||
} | } | ||||
@@ -352,10 +329,5 @@ namespace Tensorflow.Keras.Engine | |||||
return output_tensors; | return output_tensors; | ||||
} | } | ||||
Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input) | |||||
{ | |||||
return tensor; | |||||
} | |||||
} | } | ||||
} | } |
@@ -11,7 +11,6 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
public partial class Model | public partial class Model | ||||
{ | { | ||||
List<(IVariableV1, NDArray)> LoadedWeights; | |||||
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) | public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) | ||||
{ | { | ||||
long fileId = Hdf5.OpenFile(filepath, true); | long fileId = Hdf5.OpenFile(filepath, true); | ||||
@@ -31,7 +30,7 @@ namespace Tensorflow.Keras.Engine | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
else | else | ||||
{ | { | ||||
LoadedWeights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); | |||||
hdf5_format.load_weights_from_hdf5_group(fileId, Layers); | |||||
Hdf5.CloseFile(fileId); | Hdf5.CloseFile(fileId); | ||||
} | } | ||||
} | } | ||||
@@ -33,20 +33,13 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
NodeArgs args; | NodeArgs args; | ||||
public int[] node_indices; | |||||
public int[] tensor_indices; | |||||
public Tensors input_tensors => is_input ? Outputs : args.InputTensors; | public Tensors input_tensors => is_input ? Outputs : args.InputTensors; | ||||
public Tensors Outputs => args.Outputs; | public Tensors Outputs => args.Outputs; | ||||
public Shape[] input_shapes; | |||||
public Shape[] output_shapes; | |||||
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>(); | public List<Tensor> KerasInputs { get; set; } = new List<Tensor>(); | ||||
ILayer _layer; | ILayer _layer; | ||||
public ILayer Layer => _layer; | public ILayer Layer => _layer; | ||||
public bool is_input => args.InputTensors == null; | public bool is_input => args.InputTensors == null; | ||||
public long[] FlatInputIds { get; set; } | |||||
public long[] FlatOutputIds { get; set; } | |||||
bool _single_positional_tensor_passed => KerasInputs.Count() == 1; | |||||
Dictionary<int, long> _keras_inputs_ids_and_indices = new Dictionary<int, long>(); | |||||
public INode[] ParentNodes | public INode[] ParentNodes | ||||
{ | { | ||||
get | get | ||||
@@ -74,9 +67,6 @@ namespace Tensorflow.Keras.Engine | |||||
if (args.InputTensors != null) | if (args.InputTensors != null) | ||||
KerasInputs.AddRange(args.InputTensors); | KerasInputs.AddRange(args.InputTensors); | ||||
foreach (var (i, ele) in enumerate(KerasInputs)) | |||||
_keras_inputs_ids_and_indices[i] = ele.Id; | |||||
// Wire up Node to Layers. | // Wire up Node to Layers. | ||||
layer.InboundNodes.Add(this); | layer.InboundNodes.Add(this); | ||||
@@ -93,10 +83,6 @@ namespace Tensorflow.Keras.Engine | |||||
var node_index = layer.InboundNodes.Count - 1; | var node_index = layer.InboundNodes.Count - 1; | ||||
foreach (var (i, tensor) in enumerate(Outputs)) | foreach (var (i, tensor) in enumerate(Outputs)) | ||||
tensor.KerasHistory = new KerasHistory(layer, node_index, i); | tensor.KerasHistory = new KerasHistory(layer, node_index, i); | ||||
// Cached for performance. | |||||
FlatInputIds = KerasInputs.Select(x => x.Id).ToArray(); | |||||
FlatOutputIds = Outputs.Select(x => x.Id).ToArray(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -106,16 +92,16 @@ namespace Tensorflow.Keras.Engine | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensors MapArguments(Dictionary<long, Queue<Tensor>> tensor_dict) | public Tensors MapArguments(Dictionary<long, Queue<Tensor>> tensor_dict) | ||||
{ | { | ||||
if (_single_positional_tensor_passed) | |||||
if (KerasInputs.Count() == 1) | |||||
{ | { | ||||
var kt_id = _keras_inputs_ids_and_indices[0]; | |||||
var kt_id = KerasInputs[0].Id; | |||||
return tensor_dict[kt_id].Dequeue(); | return tensor_dict[kt_id].Dequeue(); | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
var flat_arguments = KerasInputs.Select(x => x).ToArray(); | var flat_arguments = KerasInputs.Select(x => x).ToArray(); | ||||
foreach (var (kt_index, kt_id) in enumerate(_keras_inputs_ids_and_indices)) | |||||
flat_arguments[kt_index] = tensor_dict[kt_id].Dequeue(); | |||||
foreach (var (kt_index, kt) in enumerate(KerasInputs)) | |||||
flat_arguments[kt_index] = tensor_dict[kt.Id].Dequeue(); | |||||
return flat_arguments; | return flat_arguments; | ||||
} | } | ||||
@@ -3,12 +3,10 @@ using System.Collections.Generic; | |||||
using System.Text; | using System.Text; | ||||
using HDF.PInvoke; | using HDF.PInvoke; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Keras.Engine; | |||||
using HDF5CSharp; | using HDF5CSharp; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Util; | |||||
namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
{ | { | ||||
public class hdf5_format | public class hdf5_format | ||||
@@ -82,7 +80,7 @@ namespace Tensorflow.Keras.Saving | |||||
} | } | ||||
public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers) | |||||
public static void load_weights_from_hdf5_group(long f, List<ILayer> layers) | |||||
{ | { | ||||
string original_keras_version = "2.5.0"; | string original_keras_version = "2.5.0"; | ||||
string original_backend = null; | string original_backend = null; | ||||
@@ -158,7 +156,6 @@ namespace Tensorflow.Keras.Saving | |||||
} | } | ||||
keras.backend.batch_set_value(weight_value_tuples); | keras.backend.batch_set_value(weight_value_tuples); | ||||
return weight_value_tuples; | |||||
} | } | ||||
public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false) | public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false) | ||||
@@ -63,7 +63,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.139" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.139" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.1" /> | <PackageReference Include="Newtonsoft.Json" Version="13.0.1" /> | ||||
<PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" /> | <PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" /> | ||||
<PackageReference Include="SharpZipLib" Version="1.3.2" /> | |||||
<PackageReference Include="SharpZipLib" Version="1.3.3" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||