diff --git a/reproduction/utils.py b/reproduction/utils.py index 26b2014c..4f0d021e 100644 --- a/reproduction/utils.py +++ b/reproduction/utils.py @@ -13,22 +13,30 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: } 如果paths为不合法的,将直接进行raise相应的错误 - :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, - test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 + 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 :return: """ if isinstance(paths, str): if os.path.isfile(paths): return {'train': paths} elif os.path.isdir(paths): - train_fp = os.path.join(paths, 'train.txt') - if not os.path.isfile(train_fp): - raise FileNotFoundError(f"train.txt is not found in folder {paths}.") - files = {'train': train_fp} - for filename in ['dev.txt', 'test.txt']: - fp = os.path.join(paths, filename) - if os.path.isfile(fp): - files[filename.split('.')[0]] = fp + filenames = os.listdir(paths) + files = {} + for filename in filenames: + path_pair = None + if 'train' in filename: + path_pair = ('train', filename) + if 'dev' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) + path_pair = ('dev', filename) + if 'test' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) + path_pair = ('test', filename) + if path_pair: + files[path_pair[0]] = os.path.join(paths, path_pair[1]) return files else: raise FileNotFoundError(f"{paths} is not a valid file path.")