Browse Source

1. add set_dataset_size for MindDataset 2. modify parameter dupe_factor from 5 to 10

tags/v0.3.1-alpha
jonyguo 5 years ago
parent
commit
488b74e92f
3 changed files with 25 additions and 15 deletions
  1. +1
    -1
      example/nlp_to_mindrecord/zhwiki/run.sh
  2. +1
    -1
      example/nlp_to_mindrecord/zhwiki/run_simple.sh
  3. +23
    -13
      mindspore/dataset/engine/datasets.py

+ 1
- 1
example/nlp_to_mindrecord/zhwiki/run.sh View File

@@ -83,7 +83,7 @@ for index in $(seq 0 $file_list_len); do
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 &
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then


+ 1
- 1
example/nlp_to_mindrecord/zhwiki/run_simple.sh View File

@@ -44,4 +44,4 @@ python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
--dupe_factor=10 # user defined

+ 23
- 13
mindspore/dataset/engine/datasets.py View File

@@ -2577,20 +2577,30 @@ class MindDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
if self._dataset_size is None:
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
if self.num_padded > 0:
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
return num_rows
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
else:
if self.num_padded > 0:
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
return num_rows
return self._dataset_size

# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))

def is_shuffled(self):
if self.shuffle_option is None:


Loading…
Cancel
Save