diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs index 059aef24..2e5485a4 100644 --- a/src/TensorFlowNET.Core/Data/DatasetManager.cs +++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs @@ -25,10 +25,16 @@ namespace Tensorflow public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) => new TensorSliceDataset(features, labels); + public IDatasetV2 from_tensor_slices(NDArray array) + => new TensorSliceDataset(array); + public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64) => new RangeDataset(count, output_type: output_type); public IDatasetV2 range(int start, int stop, int step = 1, TF_DataType output_type = TF_DataType.TF_INT64) => new RangeDataset(stop, start: start, step: step, output_type: output_type); + + public IDatasetV2 zip(params IDatasetV2[] ds) + => new ZipDataset(ds); } } diff --git a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs index feea6d61..cbf9b847 100644 --- a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs +++ b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs @@ -11,6 +11,16 @@ namespace Tensorflow.Data { public class TensorSliceDataset : DatasetSource { + public TensorSliceDataset(NDArray array) + { + var element = tf.constant(array); + _tensors = new[] { element }; + var batched_spec = new[] { element.ToTensorSpec() }; + structure = batched_spec.Select(x => x._unbatch()).ToArray(); + + variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); + } + public TensorSliceDataset(Tensor features, Tensor labels) { _tensors = new[] { features, labels }; diff --git a/src/TensorFlowNET.Core/Data/ZipDataset.cs b/src/TensorFlowNET.Core/Data/ZipDataset.cs new file mode 100644 index 00000000..b5d7e189 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/ZipDataset.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public class ZipDataset : DatasetV2 + { + dataset_ops ops = new dataset_ops(); + public ZipDataset(params IDatasetV2[] ds) + { + var input_datasets = ds.Select(x => x.variant_tensor).ToArray(); + structure = ds.Select(x => x.structure[0]).ToArray(); + variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RescalingArgs.cs new file mode 100644 index 00000000..edbbc642 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RescalingArgs.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RescalingArgs : LayerArgs + { + public float Scale { get; set; } + public float Offset { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index dcc6f600..9121efa8 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -72,7 +72,17 @@ namespace Tensorflow public class LayersApi { - public Layer Dense(int units, + public Rescaling Rescaling(float scale, + float offset = 0, + TensorShape input_shape = null) + => new Rescaling(new RescalingArgs + { + Scale = scale, + Offset = offset, + InputShape = input_shape + }); + + public Dense Dense(int units, Activation activation = null, TensorShape input_shape = null) => new Dense(new DenseArgs diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs new file mode 100644 index 00000000..7fddb07e --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Multiply inputs by `scale` and adds `offset`. + /// + public class Rescaling : Layer + { + RescalingArgs args; + Tensor scale; + Tensor offset; + + public Rescaling(RescalingArgs args) : base(args) + { + this.args = args; + } + + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + { + scale = math_ops.cast(args.Scale, args.DType); + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.cs b/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.cs index e012d95f..73b77490 100644 --- a/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.cs +++ b/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.cs @@ -1,11 +1,20 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Preprocessings { public partial class DatasetUtils { - + public IDatasetV2 labels_to_dataset(int[] labels, string label_mode, int num_classes) + { + var label_ds = tf.data.Dataset.from_tensor_slices(labels); + if (label_mode == "binary") + throw new NotImplementedException(""); + else if(label_mode == "categorical") + throw new NotImplementedException(""); + return label_ds; + } } } diff --git a/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs b/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs index a0969ea0..a8c8d286 100644 --- a/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs +++ b/src/TensorFlowNET.Core/Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs @@ -31,8 +31,8 @@ namespace Tensorflow.Keras.Preprocessings else if (subset == "validation") { Console.WriteLine($"Using {num_val_samples} files for validation."); - samples = samples[samples.Length..]; - labels = labels[samples.Length..]; + samples = samples[(samples.Length - num_val_samples)..]; + labels = labels[(labels.Length - num_val_samples)..]; } else throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index 6f84e7b6..708d1660 100644 --- a/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Keras /// /// /// - public Tensor image_dataset_from_directory(string directory, + public IDatasetV2 image_dataset_from_directory(string directory, string labels = "inferred", string label_mode = "int", string[] class_names = null, @@ -52,8 +52,11 @@ namespace Tensorflow.Keras (image_paths, label_list) = tf.keras.preprocessing.dataset_utils.get_training_or_validation_split(image_paths, label_list, validation_split, subset); - paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation); - throw new NotImplementedException(""); + var dataset = paths_and_labels_to_dataset(image_paths, image_size, num_channels, label_list, label_mode, class_name_list.Length, interpolation); + if (shuffle) + dataset = dataset.shuffle(batch_size * 8, seed: seed); + dataset = dataset.batch(batch_size); + return dataset; } } } diff --git a/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs index 795f4e3f..e44c3534 100644 --- a/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs +++ b/src/TensorFlowNET.Core/Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs @@ -1,12 +1,14 @@ -using System; +using NumSharp; +using System; using System.Globalization; +using System.Threading.Tasks; using static Tensorflow.Binding; namespace Tensorflow.Keras { public partial class Preprocessing { - public Tensor paths_and_labels_to_dataset(string[] image_paths, + public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths, TensorShape image_size, int num_channels, int[] labels, @@ -14,10 +16,29 @@ namespace Tensorflow.Keras int num_classes, string interpolation) { - foreach (var image_path in image_paths) - path_to_image(image_path, image_size, num_channels, interpolation); + Shape shape = (image_paths.Length, image_size.dims[0], image_size.dims[1], num_channels); + Console.WriteLine($"Allocating memory for shape{shape}, {NPTypeCode.Float}"); + var data = np.zeros(shape, NPTypeCode.Float); - throw new NotImplementedException(""); + for (var i = 0; i < image_paths.Length; i++) + { + var image = path_to_image(image_paths[i], image_size, num_channels, interpolation); + data[i] = image.numpy(); + if (i % 100 == 0) + Console.WriteLine($"Filled {i}/{image_paths.Length} data into memory."); + } + + var img_ds = tf.data.Dataset.from_tensor_slices(data); + + if (label_mode == "int") + { + var label_ds = tf.keras.preprocessing.dataset_utils.labels_to_dataset(labels, label_mode, num_classes); + img_ds = tf.data.Dataset.zip(img_ds, label_ds); + } + else + throw new NotImplementedException(""); + + return img_ds; } Tensor path_to_image(string path, TensorShape image_size, int num_channels, string interpolation) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index bdffcf2b..7dcbaedd 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -67,6 +67,14 @@ namespace Tensorflow public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) { + if (tf.executing_eagerly()) + { + if (condition == null) + throw new InvalidArgumentError(""); + + return null; + } + return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => { name = scope; @@ -86,7 +94,7 @@ namespace Tensorflow var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard"); - return guarded_assert == null ? null : guarded_assert[0].op; + return guarded_assert[0].op; }); } diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 26d614e9..74b203df 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -102,6 +102,28 @@ namespace Tensorflow throw new NotImplementedException(""); } + public Tensor zip_dataset(Tensor[] input_datasets, + TF_DataType[] output_types, + TensorShape[] output_shapes, + string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ZipDataset", name, + null, + new object[] + { + input_datasets, + "output_types", output_types, + "output_shapes", output_shapes + }); + return results[0]; + } + + throw new NotImplementedException(""); + } + public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size, Tensor seed, Tensor seed2, Tensor seed_generator, TF_DataType[] output_types, TensorShape[] output_shapes, diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 12d2f7b4..1603f5b1 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -53,6 +53,7 @@ namespace TensorFlowNET.UnitTest.Basics tf_with(new ops.NameScope("scope"), scope => { string name = scope; + var const1 = tf.constant(1.0); }); tf.compat.v1.disable_eager_execution();