|
|
@@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler): |
|
|
|
|
|
|
|
@property |
|
|
|
def num_samples(self): |
|
|
|
return getattr(self.dataset, 'total_len', len(self.dataset)) |
|
|
|
""" |
|
|
|
返回样本的总数 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
total_len = getattr(self.dataset, 'total_len', None) |
|
|
|
if not isinstance(total_len, int): |
|
|
|
total_len = len(self.dataset) |
|
|
|
return total_len |
|
|
|
|
|
|
|
def __len__(self)->int: |
|
|
|
""" |
|
|
@@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): |
|
|
|
|
|
|
|
@property |
|
|
|
def num_samples(self): |
|
|
|
return getattr(self.dataset, 'total_len', len(self.dataset)) |
|
|
|
""" |
|
|
|
返回样本的总数 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
total_len = getattr(self.dataset, 'total_len', None) |
|
|
|
if not isinstance(total_len, int): |
|
|
|
total_len = len(self.dataset) |
|
|
|
return total_len |
|
|
|
|
|
|
|
def __len__(self)->int: |
|
|
|
""" |
|
|
|