diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py index bd14cf58..53adfd62 100644 --- a/fastNLP/embeddings/elmo_embedding.py +++ b/fastNLP/embeddings/elmo_embedding.py @@ -182,8 +182,8 @@ class _ElmoModel(nn.Module): raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") elif config_count == 0 or weight_count == 0: raise Exception(f"No config file or weight file found in {model_dir}") - - config = json.load(open(os.path.join(model_dir, config_file), 'r')) + with open(os.path.join(model_dir, config_file), 'r') as config_f: + config = json.load(config_f) self.weight_file = os.path.join(model_dir, weight_file) self.config = config diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index cb762eb7..4be1360b 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -11,7 +11,7 @@ import hashlib PRETRAINED_BERT_MODEL_DIR = { - 'en': 'bert-base-cased-f89bfe08.zip', + 'en': 'bert-large-cased-wwm.zip', 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', 'en-base-cased': 'bert-base-cased-f89bfe08.zip', 'en-large-uncased': 'bert-large-uncased-20939f45.zip', @@ -24,14 +24,14 @@ PRETRAINED_BERT_MODEL_DIR = { 'cn': 'bert-base-chinese-29d0a84a.zip', 'cn-base': 'bert-base-chinese-29d0a84a.zip', - 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', - 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', - 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', + 'multilingual': 'bert-base-multilingual-cased.zip', + 'multilingual-base-uncased': 'bert-base-multilingual-uncased.zip', + 'multilingual-base-cased': 'bert-base-multilingual-cased.zip', } PRETRAINED_ELMO_MODEL_DIR = { 'en': 'elmo_en-d39843fe.tar.gz', - 'cn': 'elmo_cn-5e9b34e2.tar.gz' + 'en-small': "elmo_en_Small.zip" } PRETRAIN_STATIC_FILES = { @@ -39,7 +39,7 @@ PRETRAIN_STATIC_FILES = { 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', 'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", 'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", - 'en-fasttext': "cc.en.300.vec-d53187b2.gz", + 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", 'cn': "tencent_cn-dab24577.tar.gz", 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", } @@ -47,11 +47,15 @@ PRETRAIN_STATIC_FILES = { def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: """ - 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 - 将文件放入到cache_dir中 + 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 + 将文件放入到cache_dir中. + + :param url_or_filename: 文件的下载url或者文件路径 + :param cache_dir: 文件的缓存文件夹 + :return: """ if cache_dir is None: - dataset_cache = Path(get_defalt_path()) + dataset_cache = Path(get_default_cache_path()) else: dataset_cache = cache_dir @@ -75,7 +79,7 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: def get_filepath(filepath): """ - 如果filepath中只有一个文件,则直接返回对应的全路径 + 如果filepath中只有一个文件,则直接返回对应的全路径. :param filepath: :return: """ @@ -88,7 +92,7 @@ def get_filepath(filepath): return filepath -def get_defalt_path(): +def get_default_cache_path(): """ 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 @@ -96,11 +100,10 @@ def get_defalt_path(): """ if 'FASTNLP_CACHE_DIR' in os.environ: fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') - if os.path.exists(fastnlp_cache_dir): + if os.path.isdir(fastnlp_cache_dir): return fastnlp_cache_dir - raise RuntimeError("Some errors happens on cache directory.") - else: - raise RuntimeError("There function is not available right now.") + else: + raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.") fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) return fastnlp_cache_dir @@ -109,13 +112,19 @@ def _get_base_url(name): # 返回的URL结尾必须是/ if 'FASTNLP_BASE_URL' in os.environ: fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] - return fastnlp_base_url - raise RuntimeError("There function is not available right now.") + if fastnlp_base_url.endswith('/'): + return fastnlp_base_url + else: + return fastnlp_base_url + '/' + else: + # TODO 替换 + dbbrain_url = "http://dbcloud.irocn.cn:8989/api/public/dl/" + return dbbrain_url def split_filename_suffix(filepath): """ - 给定filepath返回对应的name和suffix + 给定filepath返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 :param filepath: :return: filename, suffix """ @@ -135,13 +144,6 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: filename = re.sub(r".+/", "", url) dir_name, suffix = split_filename_suffix(filename) - sep_index = dir_name[::-1].index('-') - if sep_index<0: - check_sum = None - else: - check_sum = dir_name[-sep_index+1:] - sep_index = len(dir_name) if sep_index==-1 else -sep_index-1 - dir_name = dir_name[:sep_index] # 寻找与它名字匹配的内容, 而不关心后缀 match_dir_name = match_file(dir_name, cache_dir) @@ -154,11 +156,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: return get_filepath(cache_path) # make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 - response = requests.head(url, headers={"User-Agent": "fastNLP"}) - if response.status_code != 200: - raise IOError( - f"HEAD request failed for url {url} with status code {response.status_code}." - ) + # response = requests.head(url, headers={"User-Agent": "fastNLP"}) + # if response.status_code != 200: + # raise IOError( + # f"HEAD request failed for url {url} with status code {response.status_code}." + # ) # add ETag to filename if it exists # etag = response.headers.get("ETag") @@ -174,17 +176,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: content_length = req.headers.get("Content-Length") total = int(content_length) if content_length is not None else None progress = tqdm(unit="B", total=total) - sha256 = hashlib.sha256() with open(temp_filename, "wb") as temp_file: for chunk in req.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) temp_file.write(chunk) - sha256.update(chunk) - # check sum - digit = sha256.hexdigest()[:8] - if not check_sum: - assert digit == check_sum, "File corrupted when download." progress.close() print(f"Finish download from {url}.") @@ -193,7 +189,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: if suffix in ('.zip', '.tar.gz'): uncompress_temp_dir = tempfile.mkdtemp() delete_temp_dir = uncompress_temp_dir - print(f"Start to uncompress file to {uncompress_temp_dir}.") + print(f"Start to uncompress file to {uncompress_temp_dir}") if suffix == '.zip': unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) else: @@ -211,7 +207,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: success = False try: # 复制到指定的位置 - print(f"Copy file to {cache_path}.") + print(f"Copy file to {cache_path}") if os.path.isdir(uncompress_temp_dir): for filename in os.listdir(uncompress_temp_dir): shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) @@ -252,7 +248,7 @@ def untar_gz_file(file:Path, to:Path): tar.extractall(to) -def match_file(dir_name: str, cache_dir: str) -> str: +def match_file(dir_name: str, cache_dir: Path) -> str: """ 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 @@ -273,27 +269,3 @@ def match_file(dir_name: str, cache_dir: str) -> str: else: raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") - -if __name__ == '__main__': - cache_dir = Path('caches') - cache_dir = None - # 需要对cache_dir进行测试 - base_url = 'http://0.0.0.0:8888/file/download' - # if True: - # for filename in os.listdir(cache_dir): - # if os.path.isdir(os.path.join(cache_dir, filename)): - # shutil.rmtree(os.path.join(cache_dir, filename)) - # else: - # os.remove(os.path.join(cache_dir, filename)) - # 1. 测试.txt文件 - print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir)) - # 2. 测试.zip文件(只有一个文件) - print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir)) - # 3. 测试.zip文件(有多个文件) - print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir)) - # 4. 测试.tar.gz文件 - print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir)) - # 5. 测试.tar.gz多个文件 - print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir)) - - # 6. 测试.pkl文件 diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 9a990d9d..e73b2c40 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -563,7 +563,7 @@ class WordpieceTokenizer(object): output_tokens.append(self.unk_token) else: output_tokens.extend(sub_tokens) - if len(output_tokens)==0: + if len(output_tokens)==0: #防止里面全是空格或者回车符号 return [self.unk_token] return output_tokens