Browse Source

更新一些过时代码

tags/v0.4.10
xuyige 5 years ago
parent
commit
28d9ae0778
4 changed files with 20 additions and 4 deletions
  1. +1
    -1
      fastNLP/core/metrics.py
  2. +2
    -2
      fastNLP/io/data_loader/matching.py
  3. +3
    -0
      fastNLP/io/data_loader/sst.py
  4. +14
    -1
      fastNLP/io/file_utils.py

+ 1
- 1
fastNLP/core/metrics.py View File

@@ -737,7 +737,7 @@ def _pred_topk(y_prob, k=1):


class ExtractiveQAMetric(MetricBase):
"""
r"""
别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric`

抽取式QA(如SQuAD)的metric.


+ 2
- 2
fastNLP/io/data_loader/matching.py View File

@@ -33,8 +33,8 @@ class MatchingLoader(DataSetLoader):
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
cut_text: int = None, get_index=True, auto_pad_length: int=None,
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
extra_split: List[str]=['-'], ) -> DataBundle:
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None,
extra_split: List[str]=List['-'], ) -> DataBundle:
"""
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和


+ 3
- 0
fastNLP/io/data_loader/sst.py View File

@@ -114,6 +114,9 @@ class SST2Loader(CSVLoader):

def _load(self, path: str) -> DataSet:
ds = super(SST2Loader, self)._load(path)
for k, v in self.field.items():
if k in ds.get_field_names():
ds.rename_field(k, v)
ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
print("all count:", len(ds))
return ds


+ 14
- 1
fastNLP/io/file_utils.py View File

@@ -17,6 +17,10 @@ PRETRAINED_BERT_MODEL_DIR = {
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip',
'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip',
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

@@ -68,6 +72,7 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
"unable to parse {} as a URL or as a local path".format(url_or_filename)
)


def get_filepath(filepath):
"""
如果filepath中只有一个文件,则直接返回对应的全路径
@@ -82,6 +87,7 @@ def get_filepath(filepath):
return filepath
return filepath


def get_defalt_path():
"""
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。
@@ -98,6 +104,7 @@ def get_defalt_path():
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP"))
return fastnlp_cache_dir


def _get_base_url(name):
# 返回的URL结尾必须是/
if 'FASTNLP_BASE_URL' in os.environ:
@@ -105,6 +112,7 @@ def _get_base_url(name):
return fastnlp_base_url
raise RuntimeError("There function is not available right now.")


def split_filename_suffix(filepath):
"""
给定filepath返回对应的name和suffix
@@ -116,6 +124,7 @@ def split_filename_suffix(filepath):
return filename[:-7], '.tar.gz'
return os.path.splitext(filename)


def get_from_cache(url: str, cache_dir: Path = None) -> Path:
"""
尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。
@@ -226,6 +235,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:

return get_filepath(cache_path)


def unzip_file(file: Path, to: Path):
# unpack and write out in CoNLL column-like format
from zipfile import ZipFile
@@ -234,13 +244,15 @@ def unzip_file(file: Path, to: Path):
# Extract all the contents of zip file in current directory
zipObj.extractall(to)


def untar_gz_file(file:Path, to:Path):
import tarfile

with tarfile.open(file, 'r:gz') as tar:
tar.extractall(to)

def match_file(dir_name:str, cache_dir:str)->str:

def match_file(dir_name: str, cache_dir: str) -> str:
"""
匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串
@@ -261,6 +273,7 @@ def match_file(dir_name:str, cache_dir:str)->str:
else:
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")


if __name__ == '__main__':
cache_dir = Path('caches')
cache_dir = None


Loading…
Cancel
Save