Browse Source

更新一些过时代码

tags/v0.4.10
xuyige 5 years ago
parent
commit
28d9ae0778
4 changed files with 20 additions and 4 deletions
  1. fastNLP/core/metrics.py (+1, -1)
  2. fastNLP/io/data_loader/matching.py (+2, -2)
  3. fastNLP/io/data_loader/sst.py (+3, -0)
  4. fastNLP/io/file_utils.py (+14, -1)

+ 1
- 1
fastNLP/core/metrics.py View File

@@ -737,7 +737,7 @@ def _pred_topk(y_prob, k=1):




class ExtractiveQAMetric(MetricBase): class ExtractiveQAMetric(MetricBase):
"""
r"""
别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` 别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric`


抽取式QA(如SQuAD)的metric. 抽取式QA(如SQuAD)的metric.


+ 2
- 2
fastNLP/io/data_loader/matching.py View File

@@ -33,8 +33,8 @@ class MatchingLoader(DataSetLoader):
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
cut_text: int = None, get_index=True, auto_pad_length: int=None, cut_text: int = None, get_index=True, auto_pad_length: int=None,
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
extra_split: List[str]=['-'], ) -> DataBundle:
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None,
extra_split: List[str]=List['-'], ) -> DataBundle:
""" """
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和


+ 3
- 0
fastNLP/io/data_loader/sst.py View File

@@ -114,6 +114,9 @@ class SST2Loader(CSVLoader):


def _load(self, path: str) -> DataSet: def _load(self, path: str) -> DataSet:
ds = super(SST2Loader, self)._load(path) ds = super(SST2Loader, self)._load(path)
for k, v in self.field.items():
if k in ds.get_field_names():
ds.rename_field(k, v)
ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
print("all count:", len(ds)) print("all count:", len(ds))
return ds return ds


+ 14
- 1
fastNLP/io/file_utils.py View File

@@ -17,6 +17,10 @@ PRETRAINED_BERT_MODEL_DIR = {
'en-large-uncased': 'bert-large-uncased-20939f45.zip', 'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', 'en-large-cased': 'bert-large-cased-e0cf90fc.zip',


'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip',
'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip',
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip',

'cn': 'bert-base-chinese-29d0a84a.zip', 'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip', 'cn-base': 'bert-base-chinese-29d0a84a.zip',


@@ -68,6 +72,7 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
"unable to parse {} as a URL or as a local path".format(url_or_filename) "unable to parse {} as a URL or as a local path".format(url_or_filename)
) )



def get_filepath(filepath): def get_filepath(filepath):
""" """
如果filepath中只有一个文件,则直接返回对应的全路径 如果filepath中只有一个文件,则直接返回对应的全路径
@@ -82,6 +87,7 @@ def get_filepath(filepath):
return filepath return filepath
return filepath return filepath



def get_defalt_path(): def get_defalt_path():
""" """
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。
@@ -98,6 +104,7 @@ def get_defalt_path():
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP"))
return fastnlp_cache_dir return fastnlp_cache_dir



def _get_base_url(name): def _get_base_url(name):
# 返回的URL结尾必须是/ # 返回的URL结尾必须是/
if 'FASTNLP_BASE_URL' in os.environ: if 'FASTNLP_BASE_URL' in os.environ:
@@ -105,6 +112,7 @@ def _get_base_url(name):
return fastnlp_base_url return fastnlp_base_url
raise RuntimeError("There function is not available right now.") raise RuntimeError("There function is not available right now.")



def split_filename_suffix(filepath): def split_filename_suffix(filepath):
""" """
给定filepath返回对应的name和suffix 给定filepath返回对应的name和suffix
@@ -116,6 +124,7 @@ def split_filename_suffix(filepath):
return filename[:-7], '.tar.gz' return filename[:-7], '.tar.gz'
return os.path.splitext(filename) return os.path.splitext(filename)



def get_from_cache(url: str, cache_dir: Path = None) -> Path: def get_from_cache(url: str, cache_dir: Path = None) -> Path:
""" """
尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。
@@ -226,6 +235,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:


return get_filepath(cache_path) return get_filepath(cache_path)



def unzip_file(file: Path, to: Path): def unzip_file(file: Path, to: Path):
# unpack and write out in CoNLL column-like format # unpack and write out in CoNLL column-like format
from zipfile import ZipFile from zipfile import ZipFile
@@ -234,13 +244,15 @@ def unzip_file(file: Path, to: Path):
# Extract all the contents of zip file in current directory # Extract all the contents of zip file in current directory
zipObj.extractall(to) zipObj.extractall(to)



def untar_gz_file(file:Path, to:Path): def untar_gz_file(file:Path, to:Path):
import tarfile import tarfile


with tarfile.open(file, 'r:gz') as tar: with tarfile.open(file, 'r:gz') as tar:
tar.extractall(to) tar.extractall(to)


def match_file(dir_name:str, cache_dir:str)->str:

def match_file(dir_name: str, cache_dir: str) -> str:
""" """
匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串
@@ -261,6 +273,7 @@ def match_file(dir_name:str, cache_dir:str)->str:
else: else:
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")



if __name__ == '__main__': if __name__ == '__main__':
cache_dir = Path('caches') cache_dir = Path('caches')
cache_dir = None cache_dir = None


Loading…
Cancel
Save