|
|
@@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: |
|
|
|
} |
|
|
|
如果paths为不合法的,将直接进行raise相应的错误 |
|
|
|
|
|
|
|
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, |
|
|
|
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 |
|
|
|
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 |
|
|
|
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if isinstance(paths, str): |
|
|
|
if os.path.isfile(paths): |
|
|
|
return {'train': paths} |
|
|
|
elif os.path.isdir(paths): |
|
|
|
train_fp = os.path.join(paths, 'train.txt') |
|
|
|
if not os.path.isfile(train_fp): |
|
|
|
raise FileNotFoundError(f"train.txt is not found in folder {paths}.") |
|
|
|
files = {'train': train_fp} |
|
|
|
for filename in ['dev.txt', 'test.txt']: |
|
|
|
fp = os.path.join(paths, filename) |
|
|
|
if os.path.isfile(fp): |
|
|
|
files[filename.split('.')[0]] = fp |
|
|
|
filenames = os.listdir(paths) |
|
|
|
files = {} |
|
|
|
for filename in filenames: |
|
|
|
path_pair = None |
|
|
|
if 'train' in filename: |
|
|
|
path_pair = ('train', filename) |
|
|
|
if 'dev' in filename: |
|
|
|
if path_pair: |
|
|
|
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) |
|
|
|
path_pair = ('dev', filename) |
|
|
|
if 'test' in filename: |
|
|
|
if path_pair: |
|
|
|
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) |
|
|
|
path_pair = ('test', filename) |
|
|
|
if path_pair: |
|
|
|
files[path_pair[0]] = os.path.join(paths, path_pair[1]) |
|
|
|
return files |
|
|
|
else: |
|
|
|
raise FileNotFoundError(f"{paths} is not a valid file path.") |
|
|
|