@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import os
from typing import Mapping, Sequence, Union
@@ -8,6 +9,7 @@ import pandas as pd
import pyarrow as pa
from datasets.info import DatasetInfo
from datasets.naming import camelcase_to_snakecase
from datasets.packaged_modules import _EXTENSION_TO_MODULE as exts
from datasets.packaged_modules import csv
from datasets.utils.filelock import FileLock
@@ -190,8 +192,54 @@ class TaskSpecificDatasetBuilder(MsCsvDatasetBuilder):
class ExternalDataset(object):
    """Wrapper around per-split data locations of an external dataset.

    For every split whose path is a directory containing files that all share
    one extension, the files are collected into ``split_data_files``. When the
    detected extension is a key of ``exts`` the files are loaded through
    ``datasets.load_dataset`` and the result is stored in ``ext_dataset``;
    otherwise ``ext_dataset`` stays ``None`` and the object simply exposes the
    raw ``split name -> path`` mapping.

    Args:
        split_path_dict: Mapping of split name to the path of that split.
        config_kwargs: Extra keyword arguments. They are forwarded unchanged
            to ``datasets.load_dataset``; a deep copy extended with a
            ``'split_config'`` entry is kept in ``self.config_kwargs``. The
            caller's dict is never modified.
    """

    def __init__(self, split_path_dict, config_kwargs):
        self.split_path_dict = split_path_dict
        # Work on a private copy so that adding 'split_config' neither mutates
        # the caller's dict nor leaks into the load_dataset() call below.
        self.config_kwargs = copy.deepcopy(config_kwargs)
        self.config_kwargs.update({'split_config': split_path_dict})
        self.ext_dataset = None
        self.split_data_files = {k: [] for k in split_path_dict}
        file_ext = ''

        for split_name, split_dir in split_path_dict.items():
            if not os.path.isdir(split_dir):
                continue

            split_file_names = os.listdir(split_dir)
            set_files_exts = {
                os.path.splitext(file_name)[-1].strip('.')
                for file_name in split_file_names
            }
            # ensure these files have same extensions
            if len(set_files_exts) != 1:
                supported_exts = ','.join(exts.keys())
                logger.error(
                    f'Split-{split_name} has been ignored, please flatten your folder structure, '
                    f'and make sure these files have same extensions. '
                    f'Supported extensions: {supported_exts} .')
                continue

            # NOTE: the extension of the last valid split wins; splits with
            # differing extensions are not detected here.
            file_ext = next(iter(set_files_exts))
            self.split_data_files[split_name] = [
                os.path.join(split_dir, file_name)
                for file_name in split_file_names
            ]

        if file_ext and file_ext in exts:
            # NOTE(review): assumes the values of `exts` are directly usable as
            # the first argument of load_dataset() - confirm against the
            # pinned `datasets` version.
            module_name = exts[file_ext]
            self.ext_dataset = datasets.load_dataset(
                module_name, data_files=self.split_data_files, **config_kwargs)

    def __len__(self):
        """Return the size of the loaded dataset, else the number of splits."""
        if self.ext_dataset is None:
            return len(self.split_path_dict)
        return len(self.ext_dataset)

    def __getitem__(self, item):
        """Return the loaded dataset's entry for ``item``, else the split path.

        Without a loaded dataset an unknown split name yields ``None``.
        """
        if self.ext_dataset is None:
            return self.split_path_dict.get(item)
        return self.ext_dataset[item]

    def __iter__(self):
        """Yield ``(key, value)`` pairs from the loaded dataset or split map."""
        if self.ext_dataset is None:
            yield from self.split_path_dict.items()
        else:
            yield from self.ext_dataset.items()