- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import mindspore.dataset as ds
- import mindspore as ms
- from enum import Enum
- __all__ = [
- 'Apply',
- 'Batch',
- 'Concat',
- 'CsvDataset',
- 'Filter',
- 'Flat_map',
- 'FromGenerator',
- 'FromSlices',
- 'Map',
- 'Prefetch',
- 'Repeat',
- 'Shuffle',
- 'Skip',
- 'Take',
- 'TextFileDataset',
- 'TFRecordDataset',
- 'Zip',
- 'Dataloader',
- ]
-
-
- class ShuffleMode(str, Enum):
- GLOBAL: str = "global"
- FILES: str = "file"
-
-
- def Apply(dataset, transformation_func):
-
- return dataset.apply(transformation_func)
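-
- # Illustrative usage sketch (not part of the original module): Dataset.apply
- # expects a function that takes a dataset and returns a transformed dataset.
- # The pipeline and the `data` variable below are assumptions made for the example.
- # def _pipeline(d):
- #     return d.repeat(2).batch(5)
- # data = Apply(data, _pipeline)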
-
-
- def Batch(
- dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None,
- output_columns=None, column_order=None, pad_info=None
- ):
- '''
- Combine batch_size consecutive rows of this dataset into a single batch.
-
- Parameters
- ----------
- dataset : Dataset
- the dataset to be batched
- batch_size : int
- the number of rows combined into each batch
- drop_remainder : bool
- whether to drop the last batch if it contains fewer than batch_size rows
- num_parallel_workers : int
- the number of workers used to process batches in parallel
- per_batch_map : callable
- a function applied to each batch of the listed input_columns before the batch is returned
- input_columns : list of str
- the columns passed to per_batch_map
- output_columns : list of str
- the columns produced by per_batch_map
- column_order : list of str
- the order of the columns in the output dataset
- pad_info : dict
- padding information used to pad variable-length columns to a common shape
-
- Returns
- -------
- A batched dataset.
- '''
- return dataset.batch(
- batch_size=batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_parallel_workers,
- per_batch_map=per_batch_map, input_columns=input_columns, output_columns=output_columns,
- column_order=column_order, pad_info=pad_info
- )
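-
- # Illustrative usage sketch (not part of the original module); the array and
- # the column name 'data' are assumptions made for the example.
- # import numpy as np
- # data = FromSlices(np.arange(10, dtype=np.int32), column_names=['data'])
- # data = Batch(data, batch_size=4, drop_remainder=True) # two batches of 4 rows; the last 2 rows are dropped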
-
-
- def Concat(dataset_1, dataset_2):
-
- return dataset_1.concat(dataset_2)
-
-
- def CsvDataset(
- file_pattern, batch_size=1, column_names=None, column_defaults=None, label_name=None, select_columns=None,
- field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=ShuffleMode.GLOBAL,
- shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
- num_rows_for_inference=100, compression_type=None, ignore_errors=False, num_samples=None, num_shards=None,
- shard_id=None, cache=None
- ):
- """
- A source dataset that reads and parses comma-separated values (CSV) datasets.
-
- Examples:
- >>> import mindspore.dataset as dataset
- >>>
- >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more CSV files
- >>> dataset = dataset.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])
- """
- return ds.CSVDataset(
- dataset_files=file_pattern, field_delim=field_delim, column_defaults=column_defaults, column_names=column_names,
- num_samples=num_samples, num_parallel_workers=num_parallel_reads, shuffle=shuffle, num_shards=num_shards,
- shard_id=shard_id, cache=cache
- )
-
-
- def Filter(dataset, predicate):
-
- return dataset.filter(predicate)
-
-
- def Flat_map(dataset, map_func):
-
- return dataset.flat_map(map_func)
-
-
- def FromGenerator(
- generator, output_types, output_shapes=None, args=None, column_names=None, column_types=None, schema=None,
- num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
- python_multiprocessing=True
- ):
-
- return ds.GeneratorDataset(
- source=generator, column_names=column_names, column_types=column_types, schema=schema, num_samples=num_samples,
- num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_shards=num_shards,
- shard_id=shard_id, python_multiprocessing=python_multiprocessing
- )
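-
- # Illustrative usage sketch (not part of the original module): the TensorFlow-style
- # arguments (output_types, output_shapes, args) are accepted for compatibility, while
- # MindSpore's GeneratorDataset identifies outputs through column_names. The generator
- # and the column name 'data' are assumptions made for the example.
- # import numpy as np
- # def _gen():
- #     for i in range(5):
- #         yield (np.array([i], dtype=np.int32),)
- # data = FromGenerator(_gen, output_types=None, column_names=['data'])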
-
-
- def FromSlices(
- tensor, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None,
- shard_id=None
- ):
-
- return ds.NumpySlicesDataset(
- data=tensor, column_names=column_names, num_samples=num_samples, num_parallel_workers=num_parallel_workers,
- shuffle=shuffle, sampler=sampler, num_shards=num_shards, shard_id=shard_id
- )
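-
- # Illustrative usage sketch (not part of the original module): FromSlices wraps
- # ds.NumpySlicesDataset, so the leading dimension of each array becomes the row
- # index. The arrays and column names are assumptions made for the example.
- # import numpy as np
- # x = np.random.rand(100, 28, 28).astype(np.float32)
- # y = np.random.randint(0, 10, size=(100,)).astype(np.int32)
- # data = FromSlices((x, y), column_names=['image', 'label'])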
-
-
- def Map(
- dataset, map_func, num_parallel_calls=None, input_columns=None, output_columns=None, column_order=None,
- num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None
- ):
- """ Maps map_func across the elements of this dataset.
-
- Parameters
- ----------
- dataset : DataFlow
- input DataFlow
- map_func : function
- A function mapping a dataset element to another dataset element.
- num_parallel_calls
-
- Returns
- -------
-
- """
- return dataset.map(
- operations=map_func, input_columns=input_columns, output_columns=output_columns, column_order=column_order,
- num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing, cache=cache,
- callbacks=callbacks
- )
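-
- # Illustrative usage sketch (not part of the original module): map_func is forwarded
- # to Dataset.map as 'operations', so it can be a Python callable or a list of
- # MindSpore transforms. The column name 'data' is an assumption made for the example.
- # import numpy as np
- # data = FromSlices(np.arange(6, dtype=np.int32), column_names=['data'])
- # data = Map(data, lambda x: x * 2, input_columns=['data'])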
-
-
- def Prefetch(dataset, buffer_size):
-
- # The prefetch size is a global MindSpore setting, so it is configured through
- # ds.config and the dataset itself is returned rather than the setter's result.
- if buffer_size is not None and buffer_size > 0:
- batch_size = dataset.get_batch_size()
- ds.config.set_prefetch_size(batch_size * buffer_size)
-
- return dataset
-
-
- def Repeat(dataset, count=None):
-
- return dataset.repeat(count)
-
-
- def Shuffle(dataset, buffer_size, seed=None, reshuffle_each_iteration=None):
-
- # The shuffle seed is a global MindSpore setting configured through ds.config;
- # reshuffle_each_iteration is accepted for API compatibility and ignored.
- if seed is not None:
- ds.config.set_seed(seed)
-
- return dataset.shuffle(buffer_size)
-
-
- def Skip(dataset, count):
- '''
- Creates a Dataset that skips count elements from this dataset.
- Parameters
- ----------
- dataset:
- A dataset
- count:
- An integer, the number of elements of this dataset that should be skipped to form the new dataset.
-
-
- Returns
- -------
-
- '''
- return dataset.skip(count)
-
-
- def Take(dataset, count):
- '''
- Creates a Dataset with at most count elements from this dataset.
- Parameters
- ----------
- dataset:
- A dataset
- count:
- An integer, the number of elements of this dataset that should be taken to form the new dataset.
- If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.
- Returns
- -------
-
- '''
- return dataset.take(count)
-
-
- def TextFileDataset(
- filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, num_samples=None, shuffle=None,
- num_shards=None, shard_id=None, cache=None
- ):
- """
- A source dataset that reads and parses datasets stored on disk in text format.
- The generated dataset has one column ['text'].
-
- Examples:
- >>> import mindspore.dataset as dataset
- >>>
- >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more text files
- >>> dataset = dataset.TextFileDataset(dataset_files=dataset_files)
- """
- if shuffle is None:
- shuffle = ShuffleMode.GLOBAL
- return ds.TextFileDataset(
- dataset_files=filenames, num_samples=num_samples, num_parallel_workers=num_parallel_reads, shuffle=shuffle,
- num_shards=num_shards, shard_id=shard_id, cache=cache
- )
-
-
- def TFRecordDataset(
- filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, schema=None, columns_list=None,
- num_samples=None, shuffle=None, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None
- ):
- """
- A source dataset that reads and parses datasets stored on disk in TFRecord format.
-
- Examples:
- >>> import mindspore.dataset as dataset
- >>> import mindspore.common.dtype as mstype
- >>>
- >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more TFRecord files
- >>>
- >>> # 1) Get all rows from dataset_files with no explicit schema
- >>> # The meta-data in the first row will be used as a schema.
- >>> tfdataset = dataset.TFRecordDataset(dataset_files=dataset_files)
- >>>
- >>> # 2) Get all rows from dataset_files with user-defined schema
- >>> schema = dataset.Schema()
- >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
- >>> tfdataset = dataset.TFRecordDataset(dataset_files=dataset_files, schema=schema)
- >>>
- >>> # 3) Get all rows from dataset_files with schema file "./schema.json"
- >>> tfdataset = dataset.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
- """
- if shuffle is None:
- shuffle = ShuffleMode.GLOBAL
- return ds.TFRecordDataset(
- dataset_files=filenames, schema=schema, columns_list=columns_list, num_samples=num_samples,
- num_parallel_workers=num_parallel_reads, shuffle=shuffle, num_shards=num_shards, shard_id=shard_id,
- shard_equal_rows=shard_equal_rows, cache=cache
- )
-
-
- def Zip(datasets):
- '''
- Creates a Dataset by zipping together the given datasets.
- Parameters
- ----------
- datasets:
- A tuple of datasets to be zipped together.
- Returns
- -------
-
- '''
- return ds.zip(datasets)
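-
- # Illustrative usage sketch (not part of the original module): ds.zip joins datasets
- # row by row, so the inputs must use distinct column names. The arrays and column
- # names are assumptions made for the example.
- # import numpy as np
- # images = FromSlices(np.zeros((8, 4), dtype=np.float32), column_names=['image'])
- # labels = FromSlices(np.arange(8, dtype=np.int32), column_names=['label'])
- # data = Zip((images, labels))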
-
-
- def Dataloader(dataset, batch_size, shuffle=False, drop_last=False, prefetch=0, shuffle_buffer_size=10000):
-
- if shuffle:
- dataset = Shuffle(dataset, buffer_size=shuffle_buffer_size)
-
- dataset = Batch(dataset, batch_size=batch_size, drop_remainder=drop_last)
- dataset = Prefetch(dataset, buffer_size=prefetch)
-
- return dataset
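-
-
- # Illustrative end-to-end sketch (not part of the original module) combining the
- # wrappers above; the arrays, column names and batch size are assumptions made
- # for the example.
- # import numpy as np
- # x = np.random.rand(100, 3).astype(np.float32)
- # y = np.random.randint(0, 2, size=(100,)).astype(np.int32)
- # data = FromSlices((x, y), column_names=['x', 'y'])
- # loader = Dataloader(data, batch_size=16, shuffle=True, drop_last=True)
- # x_batch, y_batch = next(loader.create_tuple_iterator()) # x_batch has shape (16, 3)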