diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 887a7abe..f75b6c90 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -737,7 +737,7 @@ def _pred_topk(y_prob, k=1): class ExtractiveQAMetric(MetricBase): - """ + r""" 别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` 抽取式QA(如SQuAD)的metric. diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 3f5759d6..21dcefb0 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -33,8 +33,8 @@ class MatchingLoader(DataSetLoader): to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, cut_text: int = None, get_index=True, auto_pad_length: int=None, auto_pad_token: str='', set_input: Union[list, str, bool]=True, - set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, - extra_split: List[str]=['-'], ) -> DataBundle: + set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None, + extra_split: List[str]=List['-'], ) -> DataBundle: """ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index ecbabd49..df46b47f 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -114,6 +114,9 @@ class SST2Loader(CSVLoader): def _load(self, path: str) -> DataSet: ds = super(SST2Loader, self)._load(path) + for k, v in self.field.items(): + if k in ds.get_field_names(): + ds.rename_field(k, v) ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) print("all count:", len(ds)) return ds diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index 04970cb3..cb762eb7 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -17,6 +17,10 @@ PRETRAINED_BERT_MODEL_DIR = { 'en-large-uncased': 'bert-large-uncased-20939f45.zip', 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', + 'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip', + 'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip', + 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip', + 'cn': 'bert-base-chinese-29d0a84a.zip', 'cn-base': 'bert-base-chinese-29d0a84a.zip', @@ -68,6 +72,7 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: "unable to parse {} as a URL or as a local path".format(url_or_filename) ) + def get_filepath(filepath): """ 如果filepath中只有一个文件,则直接返回对应的全路径 @@ -82,6 +87,7 @@ def get_filepath(filepath): return filepath return filepath + def get_defalt_path(): """ 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 @@ -98,6 +104,7 @@ def get_defalt_path(): fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) return fastnlp_cache_dir + def _get_base_url(name): # 返回的URL结尾必须是/ if 'FASTNLP_BASE_URL' in os.environ: @@ -105,6 +112,7 @@ def _get_base_url(name): return fastnlp_base_url raise RuntimeError("There function is not available right now.") + def split_filename_suffix(filepath): """ 给定filepath返回对应的name和suffix @@ -116,6 +124,7 @@ def split_filename_suffix(filepath): return filename[:-7], '.tar.gz' return os.path.splitext(filename) + def get_from_cache(url: str, cache_dir: Path = None) -> Path: """ 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 @@ -226,6 +235,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: return get_filepath(cache_path) + def unzip_file(file: Path, to: Path): # unpack and write out in CoNLL column-like format from zipfile import ZipFile @@ -234,13 +244,15 @@ def unzip_file(file: Path, to: Path): # Extract all the contents of zip file in current directory zipObj.extractall(to) + def untar_gz_file(file:Path, to:Path): import tarfile with tarfile.open(file, 'r:gz') as tar: tar.extractall(to) -def match_file(dir_name:str, cache_dir:str)->str: + +def match_file(dir_name: str, cache_dir: str) -> str: """ 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 @@ -261,6 +273,7 @@ def match_file(dir_name:str, cache_dir:str)->str: else: raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") + if __name__ == '__main__': cache_dir = Path('caches') cache_dir = None