From 615105ca2b06ad56da3407c5574c979b080059bc Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 19 Dec 2020 15:48:46 -0600 Subject: [PATCH] Fix epoch bug for dataset. #666 --- .../Engine/DataAdapters/DataAdapter.cs | 34 +++++++++++++++++++ .../Engine/DataAdapters/DataHandler.cs | 6 ++-- .../Engine/DataAdapters/DatasetAdapter.cs | 22 ++---------- .../Engine/DataAdapters/IDataAdapter.cs | 1 + .../DataAdapters/TensorLikeDataAdapter.cs | 25 +++----------- src/TensorFlowNET.Keras/Utils/Web.cs | 4 +-- 6 files changed, 48 insertions(+), 44 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs new file mode 100644 index 00000000..1a179854 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine.DataAdapters +{ + public abstract class DataAdapter + { + protected DataAdapterArgs args; + protected IDatasetV2 dataset; + + public virtual bool CanHandle(Tensor x, Tensor y = null) + => throw new NotImplementedException(); + + public virtual IDatasetV2 GetDataset() + => dataset; + + public virtual int GetSize() + => throw new NotImplementedException(""); + + public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y) + { + if (y.TensorShape.ndim == 1) + y = array_ops.expand_dims(y, axis: -1); + return (x, y); + } + + public virtual bool ShouldRecreateIterator() + { + return true; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 1bcb9c1d..950d2c98 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -91,12 +91,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters public IEnumerable<(int, OwnedIterator)> enumerate_epochs() { - using var ownedIterator = new OwnedIterator(_dataset); + var data_iterator = new OwnedIterator(_dataset); foreach (var epoch in range(_initial_epoch, _epochs)) { if (_insufficient_data) break; - yield return (epoch, ownedIterator); + if (_adapter.ShouldRecreateIterator()) + data_iterator = new OwnedIterator(_dataset); + yield return (epoch, data_iterator); } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs index d5f9613f..29b0e58b 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs @@ -5,31 +5,15 @@ using Tensorflow.Keras.ArgsDefinition; namespace Tensorflow.Keras.Engine.DataAdapters { - public class DatasetAdapter : IDataAdapter + public class DatasetAdapter : DataAdapter, IDataAdapter { - DataAdapterArgs args; - IDatasetV2 _dataset => args.Dataset; public DatasetAdapter(DataAdapterArgs args) { this.args = args; + dataset = args.Dataset; } - public bool CanHandle(Tensor x, Tensor y = null) - { - throw new NotImplementedException(); - } - - public IDatasetV2 GetDataset() - => _dataset; - - public int GetSize() + public override int GetSize() => -1; - - public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) - { - if (y.TensorShape.ndim == 1) - y = array_ops.expand_dims(y, axis: -1); - return (x, y); - } } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs index 587ebfe0..df414b9f 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs @@ -17,5 +17,6 @@ IDatasetV2 GetDataset(); int GetSize(); (Tensor, Tensor) Expand1d(Tensor x, Tensor y); + bool ShouldRecreateIterator(); } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index ecf5cbf9..0e8f74c3 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -7,14 +7,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters /// /// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy. /// - public class TensorLikeDataAdapter : IDataAdapter + public class TensorLikeDataAdapter : DataAdapter, IDataAdapter { - DataAdapterArgs args; int _size; int _batch_size; int num_samples; int num_full_batches; - IDatasetV2 _dataset; public TensorLikeDataAdapter(DataAdapterArgs args) { @@ -31,7 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters indices_dataset = indices_dataset.repeat(); indices_dataset = indices_dataset.map(permutation).prefetch(1); indices_dataset = indices_dataset.flat_map(slice_batch_indices); - _dataset = slice_inputs(indices_dataset, args.X, args.Y); + dataset = slice_inputs(indices_dataset, args.X, args.Y); } Tensor permutation(Tensor tensor) @@ -73,26 +71,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters return dataset; } - public bool CanHandle(Tensor x, Tensor y = null) - { - throw new NotImplementedException(); - } - - void _process_tensorlike() - { - } - - public IDatasetV2 GetDataset() - => _dataset; - - public int GetSize() + public override int GetSize() => _size; - public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) + void _process_tensorlike() { - if (y.TensorShape.ndim == 1) - y = array_ops.expand_dims(y, axis: -1); - return (x, y); } } } diff --git a/src/TensorFlowNET.Keras/Utils/Web.cs b/src/TensorFlowNET.Keras/Utils/Web.cs index 839b6470..4e9f09d9 100644 --- a/src/TensorFlowNET.Keras/Utils/Web.cs +++ b/src/TensorFlowNET.Keras/Utils/Web.cs @@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Utils } var wc = new WebClient(); - Console.WriteLine($"Downloading {relativeFilePath}"); + Console.WriteLine($"Downloading from {url}"); var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); while (!download.IsCompleted) { @@ -49,7 +49,7 @@ namespace Tensorflow.Keras.Utils Console.Write("."); } Console.WriteLine(""); - Console.WriteLine($"Downloaded {relativeFilePath}"); + Console.WriteLine($"Downloaded to {relativeFilePath}"); return true; }