|
|
|
@@ -91,11 +91,14 @@ void Sampler::Print(std::ostream &out, bool show_all) const { |
|
|
|
Status Sampler::GetAllIdsThenReset(py::array *data) { |
|
|
|
std::unique_ptr<DataBuffer> db; |
|
|
|
std::shared_ptr<Tensor> sample_ids; |
|
|
|
TensorRow sample_row; |
|
|
|
|
|
|
|
// A call to derived class to get sample ids wrapped inside a buffer |
|
|
|
RETURN_IF_NOT_OK(GetNextSample(&db)); |
|
|
|
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch |
|
|
|
RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); |
|
|
|
RETURN_IF_NOT_OK(db->GetRow(0, &sample_row)); |
|
|
|
sample_ids = sample_row[0]; |
|
|
|
|
|
|
|
// check this buffer is not a ctrl buffer |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); |
|
|
|
{ |
|
|
|
|