|
|
@@ -2577,20 +2577,30 @@ class MindDataset(SourceDataset): |
|
|
Return: |
|
|
Return: |
|
|
Number, number of batches. |
|
|
Number, number of batches. |
|
|
""" |
|
|
""" |
|
|
if self.load_dataset: |
|
|
|
|
|
dataset_file = [self.dataset_file] |
|
|
|
|
|
else: |
|
|
|
|
|
dataset_file = self.dataset_file |
|
|
|
|
|
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) |
|
|
|
|
|
if self.partitions is not None and self.partitions[0] > 0: |
|
|
|
|
|
if num_rows % self.partitions[0] == 0: |
|
|
|
|
|
num_rows = num_rows // self.partitions[0] |
|
|
|
|
|
|
|
|
if self._dataset_size is None: |
|
|
|
|
|
if self.load_dataset: |
|
|
|
|
|
dataset_file = [self.dataset_file] |
|
|
else: |
|
|
else: |
|
|
if self.num_padded > 0: |
|
|
|
|
|
raise RuntimeError( |
|
|
|
|
|
"Dataset size plus number of padded samples is not divisible by number of shards.") |
|
|
|
|
|
num_rows = num_rows // self.partitions[0] + 1 |
|
|
|
|
|
return num_rows |
|
|
|
|
|
|
|
|
dataset_file = self.dataset_file |
|
|
|
|
|
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) |
|
|
|
|
|
if self.partitions is not None and self.partitions[0] > 0: |
|
|
|
|
|
if num_rows % self.partitions[0] == 0: |
|
|
|
|
|
num_rows = num_rows // self.partitions[0] |
|
|
|
|
|
else: |
|
|
|
|
|
if self.num_padded > 0: |
|
|
|
|
|
raise RuntimeError( |
|
|
|
|
|
"Dataset size plus number of padded samples is not divisible by number of shards.") |
|
|
|
|
|
num_rows = num_rows // self.partitions[0] + 1 |
|
|
|
|
|
return num_rows |
|
|
|
|
|
return self._dataset_size |
|
|
|
|
|
|
|
|
|
|
|
# manually set dataset_size as a tempoary solution. |
|
|
|
|
|
def set_dataset_size(self, value): |
|
|
|
|
|
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.") |
|
|
|
|
|
if value >= 0: |
|
|
|
|
|
self._dataset_size = value |
|
|
|
|
|
else: |
|
|
|
|
|
raise ValueError('set dataset_size with negative value {}'.format(value)) |
|
|
|
|
|
|
|
|
def is_shuffled(self): |
|
|
def is_shuffled(self): |
|
|
if self.shuffle_option is None: |
|
|
if self.shuffle_option is None: |
|
|
|