|
|
@@ -7,14 +7,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
/// <summary> |
|
|
|
/// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy. |
|
|
|
/// </summary> |
|
|
|
public class TensorLikeDataAdapter : IDataAdapter |
|
|
|
public class TensorLikeDataAdapter : DataAdapter, IDataAdapter |
|
|
|
{ |
|
|
|
DataAdapterArgs args; |
|
|
|
int _size; |
|
|
|
int _batch_size; |
|
|
|
int num_samples; |
|
|
|
int num_full_batches; |
|
|
|
IDatasetV2 _dataset; |
|
|
|
|
|
|
|
public TensorLikeDataAdapter(DataAdapterArgs args) |
|
|
|
{ |
|
|
@@ -31,7 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
indices_dataset = indices_dataset.repeat(); |
|
|
|
indices_dataset = indices_dataset.map(permutation).prefetch(1); |
|
|
|
indices_dataset = indices_dataset.flat_map(slice_batch_indices); |
|
|
|
_dataset = slice_inputs(indices_dataset, args.X, args.Y); |
|
|
|
dataset = slice_inputs(indices_dataset, args.X, args.Y); |
|
|
|
} |
|
|
|
|
|
|
|
Tensor permutation(Tensor tensor) |
|
|
@@ -73,26 +71,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
return dataset; |
|
|
|
} |
|
|
|
|
|
|
|
public bool CanHandle(Tensor x, Tensor y = null) |
|
|
|
{ |
|
|
|
throw new NotImplementedException(); |
|
|
|
} |
|
|
|
|
|
|
|
void _process_tensorlike() |
|
|
|
{ |
|
|
|
} |
|
|
|
|
|
|
|
public IDatasetV2 GetDataset() |
|
|
|
=> _dataset; |
|
|
|
|
|
|
|
public int GetSize() |
|
|
|
public override int GetSize() |
|
|
|
=> _size; |
|
|
|
|
|
|
|
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) |
|
|
|
void _process_tensorlike() |
|
|
|
{ |
|
|
|
if (y.TensorShape.ndim == 1) |
|
|
|
y = array_ops.expand_dims(y, axis: -1); |
|
|
|
return (x, y); |
|
|
|
} |
|
|
|
} |
|
|
|
} |