@@ -13,10 +13,8 @@ namespace Tensorflow
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
public async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
{
var loader = new MnistModelLoader();
var setting = new ModelLoadSetting
{
TrainDir = trainDir,
@@ -33,7 +31,7 @@ namespace Tensorflow
if (testSize.HasValue)
setting.TestSize = testSize.Value;
return await loader.LoadAsync(setting);
return await LoadAsync(setting);
}
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
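LoadAsync is no longer static, so call sites now construct a loader instance first. A minimal call-site sketch based on the signature above; the directory name and flag values are illustrative:

    var loader = new MnistModelLoader();
    var datasets = await loader.LoadAsync("mnist", oneHot: true, showProgressInConsole: true);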
@@ -42,7 +42,7 @@ namespace Tensorflow.Eager
int num_outputs)
{
var status = tf.Status;
using var op = GetOp(ctx, op_name, status);
var op = GetOp(ctx, op_name, status);
c_api.TFE_OpSetDevice(op, device_name, status.Handle);
if (status.ok())
{
@@ -15,7 +15,9 @@ namespace Tensorflow.Eager
/// </summary>
public partial class EagerRunner
{
UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>();
UnorderedMap<string, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<string, SafeOpHandle>();
public void ClearEagerOperationMap()
=> thread_local_eager_operation_map.Clear();
public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info)
{
@@ -31,7 +33,7 @@ namespace Tensorflow.Eager
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks;
var status = tf.Status;
using var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status);
var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status);
var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name);
@@ -56,8 +58,8 @@ namespace Tensorflow.Eager
}
}
c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle);
status.Check(true);
// c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle);
// status.Check(true);
// Add inferred attrs and inputs.
for (int i = 0; i < op_def.InputArg.Count; i++)
@@ -145,7 +147,6 @@ namespace Tensorflow.Eager
var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();
if (op_exec_info.run_callbacks)
{
RunCallbacks(op_exec_info,
@@ -158,19 +159,19 @@ namespace Tensorflow.Eager
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
{
/*if (thread_local_eager_operation_map.find(ctx, out var op))
if (thread_local_eager_operation_map.find(op_or_function_name, out var op))
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle);
else
{
op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
thread_local_eager_operation_map[ctx] = op;
thread_local_eager_operation_map[op_or_function_name] = op;
}
status.Check(true);
return op;*/
var op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
status.Check(true);
return op;
/*var op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
status.Check(true);
return op;*/
}
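GetOp now caches one SafeOpHandle per op name and reuses it through TFE_OpReset instead of allocating a fresh native op on every eager call; that is also why the two call sites above drop their "using var op" declarations: disposing a cached handle would invalidate the map. A simplified sketch of the reuse pattern with plain collections (GetOrCreate and its delegates are hypothetical, not the real native calls):

    // hypothetical helper mirroring the cache-or-reset logic of GetOp
    SafeOpHandle GetOrCreate(string opName, Dictionary<string, SafeOpHandle> cache,
                             Func<SafeOpHandle> create, Action<SafeOpHandle> reset)
    {
        if (cache.TryGetValue(opName, out var op))
            reset(op);                 // reuse the existing native handle
        else
            cache[opName] = op = create();
        return op;                     // callers must not dispose: the handle stays cached
    }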
bool HasAccumulator()
@@ -268,16 +269,7 @@ namespace Tensorflow.Eager
if (attr_value == null)
{
if (is_list != 0)
#pragma warning disable CS0642 // Possible mistaken empty statement
;
#pragma warning restore CS0642 // Possible mistaken empty statement
//SetOpAttrListDefault
else
#pragma warning disable CS0642 // Possible mistaken empty statement
;
#pragma warning restore CS0642 // Possible mistaken empty statement
//SetOpAttrScalarDefault
}
else
{
@@ -39,5 +39,7 @@ namespace Tensorflow.Eager
bool MustRecordGradient();
int TapeSetPossibleGradientTypes(params Tensor[] args);
void ClearEagerOperationMap();
}
}
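Since the call sites no longer dispose the cached handles, IEagerRunner gains ClearEagerOperationMap so the cache can be released explicitly; a one-line teardown sketch using the accessor that appears later in this diff:

    tf.Runner.ClearEagerOperationMap();   // drops every cached SafeOpHandle (called from clear_session below)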
@@ -44,14 +44,14 @@ namespace Tensorflow.Framework
return true;
}
if (other.IsSparseTensor)
if (other is SparseTensor)
{
return self.dtype.is_compatible_with(other.dtype);
}
return self.dtype.is_compatible_with(other.dtype) &&
_shape_is_compatible_0dim(self.shape, other.shape) &&
!self.IsSparseTensor;
!(self is SparseTensor);
}
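The IsSparseTensor convenience property is removed from Tensor later in this change, so the compatibility check switches to a plain type test; the semantics stay the same: a sparse operand only needs a compatible dtype, otherwise shapes must also be compatible. For illustration (t is any Tensor):

    bool sparse = t is SparseTensor;   // replaces the removed t.IsSparseTensor property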
public static Dimension dimension_at_index(Shape shape, int index)
@@ -30,10 +30,6 @@ namespace Tensorflow
public class BaseSession : DisposableObject
{
protected Graph _graph;
protected bool _opened;
protected bool _closed;
protected int _current_version;
protected byte[] _target;
public Graph graph => _graph;
public BaseSession(IntPtr handle, Graph g)
@@ -46,18 +42,15 @@ namespace Tensorflow
{
_graph = g ?? ops.get_default_graph();
if (!_graph.building_function)
_graph.as_default();
_target = Encoding.UTF8.GetBytes(target);
using (var opts = new SessionOptions(target, config))
{
lock (Locks.ProcessWide)
{
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
status.Check(true);
}
if (ops.get_default_graph() != _graph)
_graph.as_default();
}
using var opts = new SessionOptions(target, config);
status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
status.Check(true);
}
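The constructor loses the process-wide lock, the scoped SessionOptions block, and the unused bookkeeping fields; the native session is now created in a straight line against tf.Status. Construction at call sites is unchanged, for example the factory call reused by clear_session later in this diff:

    var sess = tf.Session(ops.get_default_graph());   // BaseSession is still DisposableObject-backed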
public virtual void run(Operation op, params FeedItem[] feed_dict)
@@ -26,80 +26,6 @@ namespace Tensorflow
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
public partial class Tensor
{
public unsafe void CopyTo(NDArray nd)
{
//if (!nd.Shape.IsContiguous)
//throw new ArgumentException("NDArray has to be contiguous (ndarray.Shape.IsContiguous).");
var length = (int)(nd.size * nd.dtypesize);
switch (nd.dtype)
{
/*case NumpyDType.Boolean:
{
CopyTo(new Span<bool>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Byte:
{
CopyTo(new Span<byte>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Int16:
{
CopyTo(new Span<short>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.UInt16:
{
CopyTo(new Span<ushort>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Int32:
{
CopyTo(new Span<int>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.UInt32:
{
CopyTo(new Span<uint>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Int64:
{
CopyTo(new Span<long>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.UInt64:
{
CopyTo(new Span<ulong>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Char:
{
CopyTo(new Span<char>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Double:
{
CopyTo(new Span<double>(nd.Address.ToPointer(), length));
break;
}
case NumpyDType.Single:
{
CopyTo(new Span<float>(nd.Address.ToPointer(), length));
break;
}*/
default:
throw new NotSupportedException();
}
}
public void CopyTo<T>(Span<T> destination) where T : unmanaged
{
throw new NotImplementedException("");
}
public TensorSpec ToTensorSpec()
=> new TensorSpec(shape, dtype, name);
}
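The deleted CopyTo overloads were already dead code: the NDArray switch was fully commented out and CopyTo<T> only threw. A hedged sketch of how tensor data can still be materialized, assuming the existing numpy() and ToArray<T>() accessors on Tensor (the constant is only illustrative):

    var t = tf.constant(new[] { 1f, 2f, 3f });
    NDArray nd = t.numpy();                // tensor contents as an NDArray
    float[] values = t.ToArray<float>();   // or copy straight into a managed array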
@@ -32,7 +32,7 @@ namespace Tensorflow
public Tensor()
{
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
/// <summary>
@@ -45,7 +45,7 @@ namespace Tensorflow
if (clone && handle != null)
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer());
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
/// <summary>
@@ -59,13 +59,13 @@ namespace Tensorflow
public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer());
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
public unsafe Tensor(NDArray nd)
{
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
#region scala
@@ -107,13 +107,13 @@ namespace Tensorflow
_value_index = value_index;
_override_dtype = dtype;
_id = ops.uid();
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -122,12 +122,12 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else
_handle = TF_NewTensor(bytes, shape, dtype);
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
}
protected unsafe void InitTensor(Array array, Shape? shape = null)
{
isCreatedInGraphMode = !tf.executing_eagerly();
_isCreatedInGraphMode = !tf.executing_eagerly();
shape = shape ?? array.GetShape();
var dtype = array.GetDataType();
@@ -17,11 +17,8 @@
using Tensorflow.NumPy;
using System;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Runtime.InteropServices;
using Tensorflow.Eager;
using Tensorflow.Framework;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -97,12 +94,9 @@ namespace Tensorflow
/// </summary>
public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle;
protected bool isCreatedInGraphMode;
protected bool _isCreatedInGraphMode;
public bool IsCreatedInGraphMode => isCreatedInGraphMode;
public bool IsSparseTensor => this is SparseTensor;
public Tensor TensorShape => tf.shape(this);
public bool IsCreatedInGraphMode => _isCreatedInGraphMode;
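The removed TensorShape property was a one-line wrapper, so callers can inline its old definition (t is any Tensor); IsSparseTensor is likewise replaced by the direct type test shown in the Framework hunk above:

    var shapeTensor = tf.shape(t);   // was t.TensorShape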
/// <summary>
/// Returns the shape of a tensor.
@@ -157,7 +151,6 @@ namespace Tensorflow
/// Keras History: (Layer, (node_index, tensor_index))
/// </summary>
public KerasHistory KerasHistory { get; set; }
public Tensor KerasMask { get; set; }
/// <summary>
/// Updates the shape of this tensor.
@@ -383,6 +383,9 @@ namespace Tensorflow
public static void reset_uid()
{
uid_number = -1;
graph_uid_number = -1;
uid_number_for_function = 0;
uid_number_for_layer = 0;
}
public static void colocate_with(bool ignore_existing = false)
@@ -126,9 +126,10 @@ namespace Tensorflow.Keras
PER_GRAPH_LAYER_NAME_UIDS.Clear();
_CURRENT_SCRATCH_GRAPH = null;
_GRAPH = null;
ops.set_default_session(tf.Session(ops.get_default_graph()));
tf.enable_eager_execution();
tf.Runner.ClearEagerOperationMap();
GC.Collect();
GC.WaitForPendingFinalizers();
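A minimal test-isolation sketch, assuming the hunk above is the Keras backend's clear_session and that keras.Sequential() is available (the model line is only illustrative):

    keras.backend.clear_session();   // fresh default session, eager mode re-enabled, TFE op cache emptied
    var model = keras.Sequential();  // later models start from clean uid/graph/session state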
@@ -12,11 +12,6 @@ namespace Tensorflow.Keras.Engine
/// </summary>
public partial class Functional : Model
{
Shape _build_input_shape;
bool _compute_output_and_mask_jointly;
bool _expects_training_arg;
bool _expects_mask_arg;
bool _autocast;
List<ILayer> _output_layers;
List<ILayer> _input_layers;
List<KerasHistory> _input_coordinates;
@@ -49,12 +44,6 @@ namespace Tensorflow.Keras.Engine
this.inputs = inputs;
this.outputs = outputs;
built = true;
_build_input_shape = inputs.shape;
_compute_output_and_mask_jointly = true;
_expects_training_arg = true;
_expects_mask_arg = true;
// A graph network does not autocast inputs, as its layers will cast them instead.
_autocast = false;
if (outputs.Any(x => x.KerasHistory == null))
base_layer_utils.create_keras_history(outputs);
@@ -303,23 +292,11 @@ namespace Tensorflow.Keras.Engine
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return run_internal_graph(inputs, training.Value);
}
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
{
if (mask == null)
{
Tensor[] masks = new Tensor[inputs.Count()];
foreach (var (i, input_t) in enumerate(inputs))
input_t.KerasMask = masks[i];
}
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
// map input values
foreach (var (x, y) in zip(this.inputs, inputs))
{
var y1 = conform_to_reference_input(y, x);
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1));
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y));
}
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
@@ -336,11 +313,11 @@ namespace Tensorflow.Keras.Engine
var layer_inputs = node.MapArguments(tensor_dict);
tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
var outputs = node.Layer.Apply(layer_inputs, is_training: training);
var outputs = node.Layer.Apply(layer_inputs, is_training: training ?? false);
foreach (var output in outputs.Where(x => x != null))
tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}");
// Update tensor_dict for next input
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
foreach (var (x_id, y) in zip(node.Outputs.Select(x => x.Id), outputs))
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));
}
}
@@ -352,10 +329,5 @@ namespace Tensorflow.Keras.Engine
return output_tensors;
}
Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)
{
return tensor;
}
}
}
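run_internal_graph now threads intermediate tensors through per-Id queues, enqueuing each produced tensor once per downstream use so that every consumer can take exactly one entry later in Node.MapArguments. The bookkeeping reduced to plain collections (x, y and usageCount stand in for the diff's variables):

    var tensorDict = new Dictionary<long, Queue<Tensor>>();
    tensorDict[x.Id] = new Queue<Tensor>(Enumerable.Repeat(y, usageCount));  // one entry per consumer
    // ...each consuming node later dequeues exactly one reference:
    var layerInput = tensorDict[x.Id].Dequeue();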
@@ -11,7 +11,6 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
List<(IVariableV1, NDArray)> LoadedWeights;
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
long fileId = Hdf5.OpenFile(filepath, true);
@@ -31,7 +30,7 @@ namespace Tensorflow.Keras.Engine
throw new NotImplementedException("");
else
{
LoadedWeights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
Hdf5.CloseFile(fileId);
}
}
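load_weights no longer keeps the (variable, value) pairs around; applying them is left entirely to the HDF5 loader. Usage stays the same (the file name is illustrative):

    model.load_weights("weights.h5");   // weights are applied inside load_weights_from_hdf5_group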
@@ -33,20 +33,13 @@ namespace Tensorflow.Keras.Engine
{
NodeArgs args;
public int[] node_indices;
public int[] tensor_indices;
public Tensors input_tensors => is_input ? Outputs : args.InputTensors;
public Tensors Outputs => args.Outputs;
public Shape[] input_shapes;
public Shape[] output_shapes;
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>();
ILayer _layer;
public ILayer Layer => _layer;
public bool is_input => args.InputTensors == null;
public long[] FlatInputIds { get; set; }
public long[] FlatOutputIds { get; set; }
bool _single_positional_tensor_passed => KerasInputs.Count() == 1;
Dictionary<int, long> _keras_inputs_ids_and_indices = new Dictionary<int, long>();
public INode[] ParentNodes
{
get
@@ -74,9 +67,6 @@ namespace Tensorflow.Keras.Engine
if (args.InputTensors != null)
KerasInputs.AddRange(args.InputTensors);
foreach (var (i, ele) in enumerate(KerasInputs))
_keras_inputs_ids_and_indices[i] = ele.Id;
// Wire up Node to Layers.
layer.InboundNodes.Add(this);
@@ -93,10 +83,6 @@ namespace Tensorflow.Keras.Engine
var node_index = layer.InboundNodes.Count - 1;
foreach (var (i, tensor) in enumerate(Outputs))
tensor.KerasHistory = new KerasHistory(layer, node_index, i);
// Cached for performance.
FlatInputIds = KerasInputs.Select(x => x.Id).ToArray();
FlatOutputIds = Outputs.Select(x => x.Id).ToArray();
}
/// <summary>
@@ -106,16 +92,16 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns>
public Tensors MapArguments(Dictionary<long, Queue<Tensor>> tensor_dict)
{
if (_single_positional_tensor_passed)
if (KerasInputs.Count() == 1)
{
var kt_id = _keras_inputs_ids_and_indices[0];
var kt_id = KerasInputs[0].Id;
return tensor_dict[kt_id].Dequeue();
}
else
{
var flat_arguments = KerasInputs.Select(x => x).ToArray();
foreach (var (kt_index, kt_id) in enumerate(_keras_inputs_ids_and_indices))
flat_arguments[kt_index] = tensor_dict[kt_id].Dequeue();
foreach (var (kt_index, kt) in enumerate(KerasInputs))
flat_arguments[kt_index] = tensor_dict[kt.Id].Dequeue();
return flat_arguments;
}
@@ -3,12 +3,10 @@ using System.Collections.Generic;
using System.Text;
using HDF.PInvoke;
using Tensorflow.NumPy;
using Tensorflow.Keras.Engine;
using HDF5CSharp;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Linq;
using Tensorflow.Util;
namespace Tensorflow.Keras.Saving
{
public class hdf5_format
@@ -82,7 +80,7 @@ namespace Tensorflow.Keras.Saving
}
public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers)
public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
{
string original_keras_version = "2.5.0";
string original_backend = null;
@@ -158,7 +156,6 @@ namespace Tensorflow.Keras.Saving
}
keras.backend.batch_set_value(weight_value_tuples);
return weight_value_tuples;
}
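With the return value gone, batch_set_value becomes the method's terminal side effect: the (IVariableV1, NDArray) pairs collected per layer are assigned in one pass and then discarded (line repeated from the hunk above for context):

    keras.backend.batch_set_value(weight_value_tuples);   // assign every loaded NDArray to its variable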
public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
@@ -63,7 +63,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.139" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
<PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" />
<PackageReference Include="SharpZipLib" Version="1.3.2" />
<PackageReference Include="SharpZipLib" Version="1.3.3" />
</ItemGroup>
<ItemGroup>