Browse Source

更新data folder读取代码

tags/v0.4.10
yh_cc 5 years ago
parent
commit
2858ff84cf
1 changed files with 18 additions and 10 deletions
  1. +18
    -10
      reproduction/utils.py

+ 18
- 10
reproduction/utils.py View File

@@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
}
如果paths为不合法的,将直接进行raise相应的错误

:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt,
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, str):
if os.path.isfile(paths):
return {'train': paths}
elif os.path.isdir(paths):
train_fp = os.path.join(paths, 'train.txt')
if not os.path.isfile(train_fp):
raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
files = {'train': train_fp}
for filename in ['dev.txt', 'test.txt']:
fp = os.path.join(paths, filename)
if os.path.isfile(fp):
files[filename.split('.')[0]] = fp
filenames = os.listdir(paths)
files = {}
for filename in filenames:
path_pair = None
if 'train' in filename:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")


Loading…
Cancel
Save