diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 874ad895..edb8a67f 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ @@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): @property def num_samples(self): - return getattr(self.dataset, 'total_len', len(self.dataset)) + """ + 返回样本的总数 + + :return: + """ + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len def __len__(self)->int: """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index a1dc318e..fe38a808 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler): :return: """ - return getattr(self.dataset, 'total_len', len(self.dataset)) + total_len = getattr(self.dataset, 'total_len', None) + if not isinstance(total_len, int): + total_len = len(self.dataset) + return total_len class SequentialSampler(RandomSampler): """