- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
-
- import tensorflow as tf
-
- __all__ = [
-     'Apply',
-     'Batch',
-     'Concat',
-     'CsvDataset',
-     'Filter',
-     'Flat_map',
-     'FromGenerator',
-     'FromSlices',
-     'Map',
-     'Prefetch',
-     'Repeat',
-     'Shuffle',
-     'Skip',
-     'Take',
-     'TextFileDataset',
-     'TFRecordDataset',
-     'Zip',
-     'Dataloader',
- ]
-
-
- def Apply(dataset, transformation_func):
-     """Applies a transformation function to the given dataset.
-
-     `Apply` enables chaining of custom `Dataset` transformations, which are
-     represented as functions that take one `Dataset` argument and return a
-     transformed `Dataset`.
-
-     >>> dataset = tf.data.Dataset.range(100)
-     >>> def dataset_fn(dataset):
-     ...     return dataset.filter(lambda x: x < 5)
-     >>> dataset = Apply(dataset, dataset_fn)
-     >>> list(dataset.as_numpy_iterator())
-     [0, 1, 2, 3, 4]
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     transformation_func :
-         A function that takes one `Dataset` argument and returns a `Dataset`.
-
-     Returns
-     -------
-         The `Dataset` returned by applying `transformation_func` to the given dataset.
-     """
-     return dataset.apply(transformation_func)
-
-
- def Batch(dataset, batch_size, drop_remainder=False):
-     '''
-     Combines consecutive elements of this dataset into batches.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     batch_size :
-         A tf.int64 scalar, the number of consecutive elements to combine into a single batch.
-     drop_remainder :
-         A tf.bool scalar, whether the last batch should be dropped if it has fewer than batch_size elements.
-
-     Returns
-     -------
-         A batched Dataset.
-     '''
-     return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
-
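- # A minimal usage sketch for Batch (illustrative only; assumes TF 2.x eager execution):
- # >>> ds = Batch(tf.data.Dataset.range(10), batch_size=4, drop_remainder=True)
- # >>> [b.tolist() for b in ds.as_numpy_iterator()]
- # [[0, 1, 2, 3], [4, 5, 6, 7]]
-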
-
- def Concat(dataset_1, dataset_2):
-     '''Creates a Dataset by concatenating `dataset_2` to the end of `dataset_1`.'''
-     return dataset_1.concatenate(dataset_2)
-
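- # A minimal usage sketch for Concat (illustrative only):
- # >>> ds = Concat(tf.data.Dataset.range(3), tf.data.Dataset.range(3))
- # >>> list(ds.as_numpy_iterator())
- # [0, 1, 2, 0, 1, 2]
-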
-
- def CsvDataset(
-     file_pattern, batch_size=1, column_names=None, column_defaults=None, label_name=None, select_columns=None,
-     field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=True,
-     shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
-     num_rows_for_inference=100, compression_type=None, ignore_errors=False, num_samples=None, num_shards=None,
-     shard_id=None, cache=None
- ):
-     """Reads CSV files into a dataset.
-
-     Reads CSV files into a dataset, where each element is a (features, labels)
-     tuple that corresponds to a batch of CSV rows. The features dictionary
-     maps feature column names to `Tensor`s containing the corresponding
-     feature data, and labels is a `Tensor` containing the batch's label data.
-
-     `num_samples`, `num_shards`, `shard_id` and `cache` are accepted for API
-     compatibility with other backends and are ignored here.
-     """
-     return tf.data.experimental.make_csv_dataset(
-         file_pattern, batch_size, column_names=column_names, column_defaults=column_defaults,
-         label_name=label_name, select_columns=select_columns, field_delim=field_delim,
-         use_quote_delim=use_quote_delim, na_value=na_value, header=header, num_epochs=num_epochs,
-         shuffle=shuffle, shuffle_buffer_size=shuffle_buffer_size, shuffle_seed=shuffle_seed,
-         prefetch_buffer_size=prefetch_buffer_size, num_parallel_reads=num_parallel_reads, sloppy=sloppy,
-         num_rows_for_inference=num_rows_for_inference, compression_type=compression_type,
-         ignore_errors=ignore_errors
-     )
-
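- # A minimal usage sketch for CsvDataset; 'train.csv' and the 'label' column
- # are hypothetical placeholders:
- # >>> ds = CsvDataset('train.csv', batch_size=32, label_name='label')
- # >>> features, labels = next(iter(ds))
-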
-
- def Filter(dataset, predicate):
-     '''
-     Filters this dataset according to predicate.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     predicate :
-         A function mapping a dataset element to a boolean.
-
-     Returns
-     -------
-         The Dataset containing the elements of this dataset for which predicate is True.
-     '''
-     return dataset.filter(predicate)
-
-
- def Flat_map(dataset, map_func):
-     '''
-     Maps map_func across this dataset and flattens the result.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     map_func :
-         A function mapping a dataset element to a dataset.
-
-     Returns
-     -------
-         A flattened Dataset.
-     '''
-     return dataset.flat_map(map_func)
-
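- # A minimal usage sketch for Flat_map (illustrative only): each element is
- # mapped to a small dataset, and the results are flattened into one stream.
- # >>> ds = FromSlices([[1, 2], [3, 4]])
- # >>> ds = Flat_map(ds, lambda x: tf.data.Dataset.from_tensor_slices(x))
- # >>> list(ds.as_numpy_iterator())
- # [1, 2, 3, 4]
-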
-
- def FromGenerator(
-     generator, output_types, output_shapes=None, args=None, column_names=None, column_types=None, schema=None,
-     num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
-     python_multiprocessing=True
- ):
-     """Creates a `Dataset` whose elements are generated by `generator`.
-
-     Parameters
-     ----------
-     generator :
-         A callable object returning an object that supports the `iter()` protocol.
-     output_types :
-         A nested structure of `tf.DType` objects matching the elements yielded by the generator.
-     output_shapes :
-         A nested structure of `tf.TensorShape` objects matching the elements yielded by the generator.
-     args :
-         A tuple of `tf.Tensor` objects to be evaluated and passed to `generator` as arguments.
-
-     The remaining parameters are accepted for API compatibility with other backends and are ignored here.
-     """
-     return tf.data.Dataset.from_generator(generator, output_types, output_shapes=output_shapes, args=args)
-
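- # A minimal usage sketch for FromGenerator; `gen` is an illustrative generator:
- # >>> def gen():
- # ...     for i in range(3):
- # ...         yield i
- # >>> ds = FromGenerator(gen, output_types=tf.int32)
- # >>> list(ds.as_numpy_iterator())
- # [0, 1, 2]
-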
-
- def FromSlices(
-     tensor, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None,
-     shard_id=None
- ):
-     '''Creates a Dataset whose elements are slices of the given tensors.
-
-     Only `tensor` is used here; the remaining parameters are accepted for API
-     compatibility with other backends and are ignored.
-     '''
-     return tf.data.Dataset.from_tensor_slices(tensor)
-
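- # A minimal usage sketch for FromSlices (illustrative only): slicing a pair
- # of arrays yields (feature, label) tuples.
- # >>> ds = FromSlices(([1, 2, 3], [0, 1, 0]))
- # >>> list(ds.as_numpy_iterator())
- # [(1, 0), (2, 1), (3, 0)]
-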
-
- def Map(
-     dataset, map_func, num_parallel_calls=None, input_columns=None, output_columns=None, column_order=None,
-     num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None
- ):
-     """Maps map_func across the elements of this dataset.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     map_func :
-         A function mapping a dataset element to another dataset element.
-     num_parallel_calls :
-         A tf.int64 scalar, the number of elements to process in parallel; pass AUTOTUNE to tune it dynamically.
-
-     The remaining parameters are accepted for API compatibility with other backends and are ignored here.
-
-     Returns
-     -------
-         The mapped Dataset.
-     """
-     return dataset.map(map_func, num_parallel_calls=num_parallel_calls)
-
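- # A minimal usage sketch for Map (illustrative only):
- # >>> ds = Map(tf.data.Dataset.range(4), lambda x: x * 2)
- # >>> list(ds.as_numpy_iterator())
- # [0, 2, 4, 6]
-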
-
- def Prefetch(dataset, buffer_size):
-     '''
-     Creates a Dataset that prefetches elements from this dataset.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     buffer_size :
-         A tf.int64 scalar, the maximum number of elements that will be buffered when prefetching.
-
-     Returns
-     -------
-         A prefetching Dataset.
-     '''
-     return dataset.prefetch(buffer_size=buffer_size)
-
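- # A minimal usage sketch for Prefetch; tf.data.experimental.AUTOTUNE lets the
- # runtime pick the buffer size (also available as tf.data.AUTOTUNE in newer TF):
- # >>> ds = Prefetch(tf.data.Dataset.range(100), tf.data.experimental.AUTOTUNE)
-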
-
- def Repeat(dataset, count=None):
-     '''Repeats this dataset `count` times; repeats indefinitely if `count` is None or -1.'''
-     return dataset.repeat(count=count)
-
-
- def Shuffle(dataset, buffer_size, seed=None, reshuffle_each_iteration=None):
-     '''Randomly shuffles the elements of this dataset using a buffer of `buffer_size` elements.'''
-     return dataset.shuffle(buffer_size, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration)
-
-
- def Skip(dataset, count):
-     '''
-     Creates a Dataset that skips count elements from this dataset.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     count :
-         A tf.int64 scalar, the number of elements of this dataset that should be skipped to form the new dataset.
-         If count is greater than the size of this dataset, the new dataset will contain no elements.
-         If count is -1, skips the entire dataset.
-
-     Returns
-     -------
-         A Dataset with the first `count` elements skipped.
-     '''
-     return dataset.skip(count)
-
-
- def Take(dataset, count):
-     '''
-     Creates a Dataset with at most count elements from this dataset.
-
-     Parameters
-     ----------
-     dataset :
-         A dataset.
-     count :
-         A tf.int64 scalar, the number of elements of this dataset that should be taken to form the new dataset.
-         If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all
-         elements of this dataset.
-
-     Returns
-     -------
-         A Dataset with at most `count` elements.
-     '''
-     return dataset.take(count)
-
-
- def TextFileDataset(
-     filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, num_samples=None, shuffle=None,
-     num_shards=None, shard_id=None, cache=None
- ):
-     '''Creates a Dataset comprising lines from one or more text files.
-
-     Only the first four parameters are used here; the rest are accepted for API
-     compatibility with other backends and are ignored.
-     '''
-     return tf.data.TextLineDataset(
-         filenames, compression_type=compression_type, buffer_size=buffer_size, num_parallel_reads=num_parallel_reads
-     )
-
-
- def TFRecordDataset(
-     filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, schema=None, columns_list=None,
-     num_samples=None, shuffle=None, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None
- ):
-     '''Creates a Dataset comprising records from one or more TFRecord files.
-
-     Only the first four parameters are used here; the rest are accepted for API
-     compatibility with other backends and are ignored.
-     '''
-     return tf.data.TFRecordDataset(
-         filenames, compression_type=compression_type, buffer_size=buffer_size, num_parallel_reads=num_parallel_reads
-     )
-
-
- def Zip(datasets):
-     '''
-     Creates a Dataset by zipping together the given datasets.
-
-     Parameters
-     ----------
-     datasets :
-         A (nested) structure of datasets, e.g. a tuple, to be zipped together.
-
-     Returns
-     -------
-         The zipped Dataset.
-     '''
-     return tf.data.Dataset.zip(datasets)
-
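- # A minimal usage sketch for Zip (illustrative only):
- # >>> a = tf.data.Dataset.range(3)
- # >>> b = tf.data.Dataset.range(3, 6)
- # >>> list(Zip((a, b)).as_numpy_iterator())
- # [(0, 3), (1, 4), (2, 5)]
-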
-
- def Dataloader(dataset, batch_size, shuffle=False, drop_last=False, prefetch=0, shuffle_buffer_size=1024):
-     '''Wraps a dataset into a batched loader, with optional shuffling and prefetching.'''
-     if shuffle:
-         dataset = Shuffle(dataset, buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)
-
-     dataset = Batch(dataset, batch_size=batch_size, drop_remainder=drop_last)
-
-     # Skip the prefetch stage when prefetch is 0 (the default), so the loader
-     # adds no buffering unless it is requested.
-     if prefetch:
-         dataset = Prefetch(dataset, buffer_size=prefetch)
-
-     return dataset
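-
- # An end-to-end sketch combining the wrappers above; the arrays are
- # illustrative placeholders:
- # >>> x = tf.random.uniform((100, 28, 28))
- # >>> y = tf.random.uniform((100,), maxval=10, dtype=tf.int32)
- # >>> loader = Dataloader(FromSlices((x, y)), batch_size=32, shuffle=True,
- # ...                     drop_last=True, prefetch=2)
- # >>> for images, labels in loader:
- # ...     pass  # images: (32, 28, 28), labels: (32,)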