|
|
@@ -41,12 +41,16 @@ def torch_default_data_collator(features): |
|
|
|
first['label'], torch.Tensor) else first['label'] |
|
|
|
# The msdataset returns a 0-dimensional np.array holding a single value; the following code handles that case.
|
|
|
if isinstance(label, np.ndarray): |
|
|
|
src_dtype = label[()].dtype |
|
|
|
dtype = torch.long if label[( |
|
|
|
)].dtype == np.int64 else torch.float |
|
|
|
else: |
|
|
|
src_dtype = type(label) |
|
|
|
dtype = torch.long if isinstance(label, int) else torch.float |
|
|
|
# add dtype to np.array to fix "TypeError: can't convert np.ndarray of type numpy.object_" |
|
|
|
batch['labels'] = torch.tensor( |
|
|
|
np.array([f['label'] for f in features]), dtype=dtype) |
|
|
|
np.array([f['label'] for f in features], dtype=src_dtype), |
|
|
|
dtype=dtype) |
|
|
|
elif 'label_ids' in first and first['label_ids'] is not None: |
|
|
|
if isinstance(first['label_ids'], torch.Tensor): |
|
|
|
batch['labels'] = torch.stack( |
|
|
|