|
|
@@ -1,4 +1,3 @@ |
|
|
|
|
|
|
|
import os |
|
|
|
from pathlib import Path |
|
|
|
from urllib.parse import urlparse |
|
|
@@ -9,35 +8,29 @@ from tqdm import tqdm |
|
|
|
import shutil |
|
|
|
from requests import HTTPError |
|
|
|
|
|
|
|
|
|
|
|
PRETRAINED_BERT_MODEL_DIR = { |
|
|
|
'en': 'bert-large-cased-wwm.zip', |
|
|
|
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', |
|
|
|
'en-base-cased': 'bert-base-cased-f89bfe08.zip', |
|
|
|
'en-large-uncased': 'bert-large-uncased-20939f45.zip', |
|
|
|
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', |
|
|
|
|
|
|
|
'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip', |
|
|
|
'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip', |
|
|
|
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip', |
|
|
|
|
|
|
|
'cn': 'bert-base-chinese-29d0a84a.zip', |
|
|
|
'cn-base': 'bert-base-chinese-29d0a84a.zip', |
|
|
|
'bert-base-chinese': 'bert-base-chinese.zip', |
|
|
|
'bert-base-cased': 'bert-base-cased.zip', |
|
|
|
'bert-base-cased-finetuned-mrpc': 'bert-base-cased-finetuned-mrpc.zip', |
|
|
|
'bert-large-cased-wwm': 'bert-large-cased-wwm.zip', |
|
|
|
'bert-large-uncased': 'bert-large-uncased.zip', |
|
|
|
'bert-large-cased': 'bert-large-cased.zip', |
|
|
|
'bert-base-uncased': 'bert-base-uncased.zip', |
|
|
|
'bert-large-uncased-wwm': 'bert-large-uncased-wwm.zip', |
|
|
|
'bert-chinese-wwm': 'bert-chinese-wwm.zip', |
|
|
|
'bert-base-multilingual-cased': 'bert-base-multilingual-cased.zip', |
|
|
|
'bert-base-multilingual-uncased': 'bert-base-multilingual-uncased.zip', |
|
|
|
'en-large-cased-wwm': 'bert-large-cased-wwm.zip', |
|
|
|
'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip', |
|
|
|
|
|
|
|
'en-large-uncased': 'bert-large-uncased.zip', |
|
|
|
'en-large-cased': 'bert-large-cased.zip', |
|
|
|
|
|
|
|
'en-base-uncased': 'bert-base-uncased.zip', |
|
|
|
'en-base-cased': 'bert-base-cased.zip', |
|
|
|
|
|
|
|
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', |
|
|
|
|
|
|
|
'en-base-multi-cased': 'bert-base-multilingual-cased.zip', |
|
|
|
'en-base-multi-uncased': 'bert-base-multilingual-uncased.zip', |
|
|
|
|
|
|
|
'cn': 'bert-chinese-wwm.zip', |
|
|
|
'cn-base': 'bert-base-chinese.zip', |
|
|
|
'cn-wwm': 'bert-chinese-wwm.zip', |
|
|
|
} |
|
|
|
|
|
|
|
PRETRAINED_ELMO_MODEL_DIR = { |
|
|
|
'en': 'elmo_en-d39843fe.tar.gz', |
|
|
|
'en': 'elmo_en_Medium.tar.gz', |
|
|
|
'en-small': "elmo_en_Small.zip", |
|
|
|
'en-original-5.5b': 'elmo_en_Original_5.5B.zip', |
|
|
|
'en-original': 'elmo_en_Original.zip', |
|
|
@@ -45,30 +38,33 @@ PRETRAINED_ELMO_MODEL_DIR = { |
|
|
|
} |
|
|
|
|
|
|
|
PRETRAIN_STATIC_FILES = { |
|
|
|
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', |
|
|
|
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', |
|
|
|
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", |
|
|
|
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", |
|
|
|
'en': 'glove.840B.300d.tar.gz', |
|
|
|
|
|
|
|
'en-glove-6b-50d': 'glove.6B.50d.zip', |
|
|
|
'en-glove-6b-100d': 'glove.6B.100d.zip', |
|
|
|
'en-glove-6b-200d': 'glove.6B.200d.zip', |
|
|
|
'en-glove-6b-300d': 'glove.6B.300d.zip', |
|
|
|
'en-glove-42b-300d': 'glove.42B.300d.zip', |
|
|
|
'en-glove-840b-300d': 'glove.840B.300d.zip', |
|
|
|
'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', |
|
|
|
'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', |
|
|
|
'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', |
|
|
|
'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', |
|
|
|
|
|
|
|
'en-word2vec-300': "GoogleNews-vectors-negative300.zip", |
|
|
|
|
|
|
|
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", |
|
|
|
'cn': "tencent_cn-dab24577.tar.gz", |
|
|
|
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", |
|
|
|
'sgns-literature-word':'sgns.literature.word.txt.zip', |
|
|
|
'glove-42b-300d': 'glove.42B.300d.zip', |
|
|
|
'glove-6b-50d': 'glove.6B.50d.zip', |
|
|
|
'glove-6b-100d': 'glove.6B.100d.zip', |
|
|
|
'glove-6b-200d': 'glove.6B.200d.zip', |
|
|
|
'glove-6b-300d': 'glove.6B.300d.zip', |
|
|
|
'glove-840b-300d': 'glove.840B.300d.zip', |
|
|
|
'glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', |
|
|
|
'glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', |
|
|
|
'glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', |
|
|
|
'glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip' |
|
|
|
} |
|
|
|
'en-fasttext-crawl': "crawl-300d-2M.vec.zip", |
|
|
|
|
|
|
|
'cn': "tencent_cn.txt.zip", |
|
|
|
'cn-tencent': "tencent_cn.txt.zip", |
|
|
|
'cn-fasttext': "cc.zh.300.vec.gz", |
|
|
|
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', |
|
|
|
} |
|
|
|
|
|
|
|
DATASET_DIR = { |
|
|
|
'aclImdb': "imdb.zip", |
|
|
|
"yelp-review-full":"yelp_review_full.tar.gz", |
|
|
|
"yelp-review-full": "yelp_review_full.tar.gz", |
|
|
|
"yelp-review-polarity": "yelp_review_polarity.tar.gz", |
|
|
|
"mnli": "MNLI.zip", |
|
|
|
"snli": "SNLI.zip", |
|
|
@@ -79,7 +75,7 @@ DATASET_DIR = { |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path: |
|
|
|
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: |
|
|
|
""" |
|
|
|
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, |
|
|
|
(1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir |
|
|
@@ -136,7 +132,7 @@ def get_filepath(filepath): |
|
|
|
""" |
|
|
|
if os.path.isdir(filepath): |
|
|
|
files = os.listdir(filepath) |
|
|
|
if len(files)==1: |
|
|
|
if len(files) == 1: |
|
|
|
return os.path.join(filepath, files[0]) |
|
|
|
else: |
|
|
|
return filepath |
|
|
@@ -180,9 +176,9 @@ def _get_base_url(name): |
|
|
|
return url + '/' |
|
|
|
else: |
|
|
|
URLS = { |
|
|
|
'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/", |
|
|
|
"dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/" |
|
|
|
} |
|
|
|
'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/", |
|
|
|
"dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/" |
|
|
|
} |
|
|
|
if name.lower() not in URLS: |
|
|
|
raise KeyError(f"{name} is not recognized.") |
|
|
|
return URLS[name.lower()] |
|
|
@@ -198,7 +194,7 @@ def _get_embedding_url(type, name): |
|
|
|
""" |
|
|
|
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, |
|
|
|
"bert": PRETRAINED_BERT_MODEL_DIR, |
|
|
|
"static":PRETRAIN_STATIC_FILES} |
|
|
|
"static": PRETRAIN_STATIC_FILES} |
|
|
|
map = PRETRAIN_MAP.get(type, None) |
|
|
|
if map: |
|
|
|
filename = map.get(name, None) |
|
|
@@ -273,16 +269,16 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: |
|
|
|
# Download to temporary file, then copy to cache dir once finished. |
|
|
|
# Otherwise you get corrupt cache entries if the download gets interrupted. |
|
|
|
fd, temp_filename = tempfile.mkstemp() |
|
|
|
print("%s not found in cache, downloading to %s"%(url, temp_filename)) |
|
|
|
print("%s not found in cache, downloading to %s" % (url, temp_filename)) |
|
|
|
|
|
|
|
# GET file object |
|
|
|
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) |
|
|
|
if req.status_code==200: |
|
|
|
if req.status_code == 200: |
|
|
|
content_length = req.headers.get("Content-Length") |
|
|
|
total = int(content_length) if content_length is not None else None |
|
|
|
progress = tqdm(unit="B", total=total, unit_scale=1) |
|
|
|
with open(temp_filename, "wb") as temp_file: |
|
|
|
for chunk in req.iter_content(chunk_size=1024*16): |
|
|
|
for chunk in req.iter_content(chunk_size=1024 * 16): |
|
|
|
if chunk: # filter out keep-alive new chunks |
|
|
|
progress.update(len(chunk)) |
|
|
|
temp_file.write(chunk) |
|
|
@@ -300,7 +296,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: |
|
|
|
else: |
|
|
|
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) |
|
|
|
filenames = os.listdir(uncompress_temp_dir) |
|
|
|
if len(filenames)==1: |
|
|
|
if len(filenames) == 1: |
|
|
|
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): |
|
|
|
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) |
|
|
|
|
|
|
@@ -316,9 +312,9 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: |
|
|
|
if os.path.isdir(uncompress_temp_dir): |
|
|
|
for filename in os.listdir(uncompress_temp_dir): |
|
|
|
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): |
|
|
|
shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path/filename) |
|
|
|
shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename) |
|
|
|
else: |
|
|
|
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) |
|
|
|
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename) |
|
|
|
else: |
|
|
|
shutil.copyfile(uncompress_temp_dir, cache_path) |
|
|
|
success = True |
|
|
@@ -350,7 +346,7 @@ def unzip_file(file: Path, to: Path): |
|
|
|
zipObj.extractall(to) |
|
|
|
|
|
|
|
|
|
|
|
def untar_gz_file(file:Path, to:Path): |
|
|
|
def untar_gz_file(file: Path, to: Path): |
|
|
|
import tarfile |
|
|
|
|
|
|
|
with tarfile.open(file, 'r:gz') as tar: |
|
|
@@ -369,12 +365,11 @@ def match_file(dir_name: str, cache_dir: Path) -> str: |
|
|
|
files = os.listdir(cache_dir) |
|
|
|
matched_filenames = [] |
|
|
|
for file_name in files: |
|
|
|
if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name): |
|
|
|
if re.match(dir_name + '$', file_name) or re.match(dir_name + '\\..*', file_name): |
|
|
|
matched_filenames.append(file_name) |
|
|
|
if len(matched_filenames)==0: |
|
|
|
if len(matched_filenames) == 0: |
|
|
|
return '' |
|
|
|
elif len(matched_filenames)==1: |
|
|
|
elif len(matched_filenames) == 1: |
|
|
|
return matched_filenames[-1] |
|
|
|
else: |
|
|
|
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") |
|
|
|
|