@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Engine.DataAdapters
+{
+    public abstract class DataAdapter
+    {
+        protected DataAdapterArgs args;
+        protected IDatasetV2 dataset;
+
+        public virtual bool CanHandle(Tensor x, Tensor y = null)
+            => throw new NotImplementedException();
+
+        public virtual IDatasetV2 GetDataset()
+            => dataset;
+
+        public virtual int GetSize()
+            => throw new NotImplementedException("");
+
+        public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
+        {
+            if (y.TensorShape.ndim == 1)
+                y = array_ops.expand_dims(y, axis: -1);
+            return (x, y);
+        }
+
+        public virtual bool ShouldRecreateIterator()
+        {
+            return true;
+        }
+    }
+}
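A quick, hedged illustration of the shared Expand1d helper above (the adapter and tensor names are placeholders, not part of this change): a rank-1 label tensor gains a trailing axis so downstream code always sees at least 2-D labels, while higher-rank labels pass through untouched.

    // Hypothetical usage; "adapter", "features" and "labels" are placeholders.
    // labels with shape (batch,) come back as (batch, 1); labels that are already
    // (batch, k) are returned unchanged, and features are never modified.
    var (x, y) = adapter.Expand1d(features, labels);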
@@ -91,12 +91,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
         {
-            using var ownedIterator = new OwnedIterator(_dataset);
+            var data_iterator = new OwnedIterator(_dataset);
             foreach (var epoch in range(_initial_epoch, _epochs))
             {
                 if (_insufficient_data)
                     break;
-                yield return (epoch, ownedIterator);
+                if (_adapter.ShouldRecreateIterator())
+                    data_iterator = new OwnedIterator(_dataset);
+                yield return (epoch, data_iterator);
             }
         }
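For context, a hedged sketch of how enumerate_epochs is meant to be consumed. OwnedIterator comes from the library; the enclosing handler class (called DataHandler in this sketch, its name is not shown in the hunk), the method RunEpochs, and the loop body are illustrative only.

    // Illustrative consumption pattern; only the (epoch, iterator) tuple shape and
    // the recreate-per-epoch behaviour come from the code above.
    static void RunEpochs(DataHandler data_handler)
    {
        foreach (var (epoch, data_iterator) in data_handler.enumerate_epochs())
        {
            // With ShouldRecreateIterator() == true (the DataAdapter default), each
            // epoch receives a freshly constructed OwnedIterator, so training re-reads
            // the dataset from the start instead of resuming a half-consumed iterator.
            // ... run this epoch's training steps against data_iterator ...
        }
    }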
@@ -5,31 +5,15 @@ using Tensorflow.Keras.ArgsDefinition;
 namespace Tensorflow.Keras.Engine.DataAdapters
 {
-    public class DatasetAdapter : IDataAdapter
+    public class DatasetAdapter : DataAdapter, IDataAdapter
     {
-        DataAdapterArgs args;
-        IDatasetV2 _dataset => args.Dataset;
         public DatasetAdapter(DataAdapterArgs args)
         {
             this.args = args;
+            dataset = args.Dataset;
         }
-        public bool CanHandle(Tensor x, Tensor y = null)
-        {
-            throw new NotImplementedException();
-        }
-        public IDatasetV2 GetDataset()
-            => _dataset;
-        public int GetSize()
+        public override int GetSize()
             => -1;
-        public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
-        {
-            if (y.TensorShape.ndim == 1)
-                y = array_ops.expand_dims(y, axis: -1);
-            return (x, y);
-        }
     }
 }
@@ -17,5 +17,6 @@
         IDatasetV2 GetDataset();
         int GetSize();
         (Tensor, Tensor) Expand1d(Tensor x, Tensor y);
+        bool ShouldRecreateIterator();
     }
 }
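Because ShouldRecreateIterator is now part of IDataAdapter, any adapter implementing the interface directly has to supply it; adapters deriving from the new DataAdapter base inherit the default of true. As a hypothetical example, an adapter whose dataset already .repeat()s across epochs could opt out of per-epoch iterator rebuilds:

    // Hypothetical override inside an adapter deriving from DataAdapter; not part of this change.
    public override bool ShouldRecreateIterator()
        => false;   // keep one OwnedIterator alive across epochs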
@@ -7,14 +7,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters
     /// <summary>
     /// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.
    /// </summary>
-    public class TensorLikeDataAdapter : IDataAdapter
+    public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
     {
-        DataAdapterArgs args;
         int _size;
         int _batch_size;
         int num_samples;
         int num_full_batches;
-        IDatasetV2 _dataset;
         public TensorLikeDataAdapter(DataAdapterArgs args)
         {
@@ -31,7 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             indices_dataset = indices_dataset.repeat();
             indices_dataset = indices_dataset.map(permutation).prefetch(1);
             indices_dataset = indices_dataset.flat_map(slice_batch_indices);
-            _dataset = slice_inputs(indices_dataset, args.X, args.Y);
+            dataset = slice_inputs(indices_dataset, args.X, args.Y);
         }
         Tensor permutation(Tensor tensor)
@@ -73,26 +71,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             return dataset;
         }
-        public bool CanHandle(Tensor x, Tensor y = null)
-        {
-            throw new NotImplementedException();
-        }
-        void _process_tensorlike()
-        {
-        }
-        public IDatasetV2 GetDataset()
-            => _dataset;
-        public int GetSize()
+        public override int GetSize()
             => _size;
-        public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
+        void _process_tensorlike()
         {
-            if (y.TensorShape.ndim == 1)
-                y = array_ops.expand_dims(y, axis: -1);
-            return (x, y);
         }
     }
 }
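A hedged construction sketch for the adapter above: only X and Y are visible in this diff, so the other DataAdapterArgs members used here (BatchSize, Epochs) and the tensors x_train/y_train are assumptions for illustration.

    // Member names other than X and Y are assumed; adjust to the actual DataAdapterArgs definition.
    var adapter = new TensorLikeDataAdapter(new DataAdapterArgs
    {
        X = x_train,      // in-memory feature tensor (placeholder)
        Y = y_train,      // in-memory label tensor (placeholder)
        BatchSize = 32,   // assumption: batch-size member
        Epochs = 10       // assumption: epoch-count member
    });
    var ds = adapter.GetDataset();    // the shuffled, batched IDatasetV2 built in the constructor
    var steps = adapter.GetSize();    // the adapter's precomputed _size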
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Utils
             }
             var wc = new WebClient();
-            Console.WriteLine($"Downloading {relativeFilePath}");
+            Console.WriteLine($"Downloading from {url}");
             var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
             while (!download.IsCompleted)
             {
@@ -49,7 +49,7 @@ namespace Tensorflow.Keras.Utils
                 Console.Write(".");
             }
             Console.WriteLine("");
-            Console.WriteLine($"Downloaded {relativeFilePath}");
+            Console.WriteLine($"Downloaded to {relativeFilePath}");
             return true;
         }
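For reference, a standalone condensation of the download-with-progress pattern in this hunk; the url and relativeFilePath values are placeholders, the WebClient/Task.Run/polling structure mirrors the code above, and a sleep is added between progress dots so the sketch does not spin.

    using System;
    using System.Net;
    using System.Threading;
    using System.Threading.Tasks;

    class DownloadSketch
    {
        static void Main()
        {
            var url = "https://example.com/data.zip";   // placeholder URL
            var relativeFilePath = "data.zip";          // placeholder destination path
            var wc = new WebClient();
            Console.WriteLine($"Downloading from {url}");
            var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
            while (!download.IsCompleted)
            {
                Thread.Sleep(1000);                     // assumption: the original likely waits between dots too
                Console.Write(".");
            }
            Console.WriteLine("");
            Console.WriteLine($"Downloaded to {relativeFilePath}");
        }
    }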