Browse Source

Fix epoch bug for dataset. #666

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
615105ca2b
6 changed files with 48 additions and 44 deletions
  1. +34
    -0
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
  2. +4
    -2
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  3. +3
    -19
      src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs
  4. +1
    -0
      src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
  5. +4
    -21
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  6. +2
    -2
      src/TensorFlowNET.Keras/Utils/Web.cs

+ 34
- 0
src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine.DataAdapters
{
public abstract class DataAdapter
{
protected DataAdapterArgs args;
protected IDatasetV2 dataset;

public virtual bool CanHandle(Tensor x, Tensor y = null)
=> throw new NotImplementedException();

public virtual IDatasetV2 GetDataset()
=> dataset;

public virtual int GetSize()
=> throw new NotImplementedException("");

public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
{
if (y.TensorShape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
return (x, y);
}

public virtual bool ShouldRecreateIterator()
{
return true;
}
}
}

+ 4
- 2
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -91,12 +91,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters

public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
using var ownedIterator = new OwnedIterator(_dataset);
var data_iterator = new OwnedIterator(_dataset);
foreach (var epoch in range(_initial_epoch, _epochs))
{
if (_insufficient_data)
break;
yield return (epoch, ownedIterator);
if (_adapter.ShouldRecreateIterator())
data_iterator = new OwnedIterator(_dataset);
yield return (epoch, data_iterator);
}
}



+ 3
- 19
src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs View File

@@ -5,31 +5,15 @@ using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine.DataAdapters
{
public class DatasetAdapter : IDataAdapter
public class DatasetAdapter : DataAdapter, IDataAdapter
{
DataAdapterArgs args;
IDatasetV2 _dataset => args.Dataset;
public DatasetAdapter(DataAdapterArgs args)
{
this.args = args;
dataset = args.Dataset;
}

public bool CanHandle(Tensor x, Tensor y = null)
{
throw new NotImplementedException();
}

public IDatasetV2 GetDataset()
=> _dataset;

public int GetSize()
public override int GetSize()
=> -1;

public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
{
if (y.TensorShape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
return (x, y);
}
}
}

+ 1
- 0
src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs View File

@@ -17,5 +17,6 @@
IDatasetV2 GetDataset();
int GetSize();
(Tensor, Tensor) Expand1d(Tensor x, Tensor y);
bool ShouldRecreateIterator();
}
}

+ 4
- 21
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -7,14 +7,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters
/// <summary>
/// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.
/// </summary>
public class TensorLikeDataAdapter : IDataAdapter
public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
{
DataAdapterArgs args;
int _size;
int _batch_size;
int num_samples;
int num_full_batches;
IDatasetV2 _dataset;

public TensorLikeDataAdapter(DataAdapterArgs args)
{
@@ -31,7 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
indices_dataset = indices_dataset.repeat();
indices_dataset = indices_dataset.map(permutation).prefetch(1);
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
_dataset = slice_inputs(indices_dataset, args.X, args.Y);
dataset = slice_inputs(indices_dataset, args.X, args.Y);
}

Tensor permutation(Tensor tensor)
@@ -73,26 +71,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return dataset;
}

public bool CanHandle(Tensor x, Tensor y = null)
{
throw new NotImplementedException();
}

void _process_tensorlike()
{
}

public IDatasetV2 GetDataset()
=> _dataset;

public int GetSize()
public override int GetSize()
=> _size;

public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
void _process_tensorlike()
{
if (y.TensorShape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
return (x, y);
}
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Utils/Web.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Utils
}

var wc = new WebClient();
Console.WriteLine($"Downloading {relativeFilePath}");
Console.WriteLine($"Downloading from {url}");
var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
while (!download.IsCompleted)
{
@@ -49,7 +49,7 @@ namespace Tensorflow.Keras.Utils
Console.Write(".");
}
Console.WriteLine("");
Console.WriteLine($"Downloaded {relativeFilePath}");
Console.WriteLine($"Downloaded to {relativeFilePath}");

return true;
}


Loading…
Cancel
Save