|
|
@@ -0,0 +1,138 @@ |
|
|
|
import torch |
|
|
|
|
|
|
|
import hashlib |
|
|
|
import os |
|
|
|
import re |
|
|
|
import shutil |
|
|
|
import sys |
|
|
|
import tempfile |
|
|
|
|
|
|
|
try: |
|
|
|
from requests.utils import urlparse |
|
|
|
from requests import get as urlopen |
|
|
|
requests_available = True |
|
|
|
except ImportError: |
|
|
|
requests_available = False |
|
|
|
if sys.version_info[0] == 2: |
|
|
|
from urlparse import urlparse # noqa f811 |
|
|
|
from urllib2 import urlopen # noqa f811 |
|
|
|
else: |
|
|
|
from urllib.request import urlopen |
|
|
|
from urllib.parse import urlparse |
|
|
|
try: |
|
|
|
from tqdm import tqdm |
|
|
|
except ImportError: |
|
|
|
tqdm = None # defined below |
|
|
|
|
|
|
|
# matches bfd8deac from resnet18-bfd8deac.pth |
|
|
|
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') |
|
|
|
|
|
|
|
|
|
|
|
def load_url(url, model_dir=None, map_location=None, progress=True): |
|
|
|
r"""Loads the Torch serialized object at the given URL. |
|
|
|
|
|
|
|
If the object is already present in `model_dir`, it's deserialized and |
|
|
|
returned. The filename part of the URL should follow the naming convention |
|
|
|
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more |
|
|
|
digits of the SHA256 hash of the contents of the file. The hash is used to |
|
|
|
ensure unique names and to verify the contents of the file. |
|
|
|
|
|
|
|
The default value of `model_dir` is ``$TORCH_HOME/models`` where |
|
|
|
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be |
|
|
|
overridden with the ``$TORCH_MODEL_ZOO`` environment variable. |
|
|
|
|
|
|
|
Args: |
|
|
|
url (string): URL of the object to download |
|
|
|
model_dir (string, optional): directory in which to save the object |
|
|
|
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) |
|
|
|
progress (bool, optional): whether or not to display a progress bar to stderr |
|
|
|
|
|
|
|
Example: |
|
|
|
# >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') |
|
|
|
|
|
|
|
""" |
|
|
|
if model_dir is None: |
|
|
|
torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP')) |
|
|
|
model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models')) |
|
|
|
if not os.path.exists(model_dir): |
|
|
|
os.makedirs(model_dir) |
|
|
|
parts = urlparse(url) |
|
|
|
filename = os.path.basename(parts.path) |
|
|
|
cached_file = os.path.join(model_dir, filename) |
|
|
|
if not os.path.exists(cached_file): |
|
|
|
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) |
|
|
|
# hash_prefix = HASH_REGEX.search(filename).group(1) |
|
|
|
_download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) |
|
|
|
return torch.load(cached_file, map_location=map_location) |
|
|
|
|
|
|
|
|
|
|
|
def _download_url_to_file(url, dst, hash_prefix, progress): |
|
|
|
if requests_available: |
|
|
|
u = urlopen(url, stream=True) |
|
|
|
file_size = int(u.headers["Content-Length"]) |
|
|
|
u = u.raw |
|
|
|
else: |
|
|
|
u = urlopen(url) |
|
|
|
meta = u.info() |
|
|
|
if hasattr(meta, 'getheaders'): |
|
|
|
file_size = int(meta.getheaders("Content-Length")[0]) |
|
|
|
else: |
|
|
|
file_size = int(meta.get_all("Content-Length")[0]) |
|
|
|
|
|
|
|
f = tempfile.NamedTemporaryFile(delete=False) |
|
|
|
try: |
|
|
|
if hash_prefix is not None: |
|
|
|
sha256 = hashlib.sha256() |
|
|
|
with tqdm(total=file_size, disable=not progress) as pbar: |
|
|
|
while True: |
|
|
|
buffer = u.read(8192) |
|
|
|
if len(buffer) == 0: |
|
|
|
break |
|
|
|
f.write(buffer) |
|
|
|
if hash_prefix is not None: |
|
|
|
sha256.update(buffer) |
|
|
|
pbar.update(len(buffer)) |
|
|
|
|
|
|
|
f.close() |
|
|
|
if hash_prefix is not None: |
|
|
|
digest = sha256.hexdigest() |
|
|
|
if digest[:len(hash_prefix)] != hash_prefix: |
|
|
|
raise RuntimeError('invalid hash value (expected "{}", got "{}")' |
|
|
|
.format(hash_prefix, digest)) |
|
|
|
shutil.move(f.name, dst) |
|
|
|
finally: |
|
|
|
f.close() |
|
|
|
if os.path.exists(f.name): |
|
|
|
os.remove(f.name) |
|
|
|
|
|
|
|
|
|
|
|
if tqdm is None: |
|
|
|
# fake tqdm if it's not installed |
|
|
|
class tqdm(object): |
|
|
|
|
|
|
|
def __init__(self, total, disable=False): |
|
|
|
self.total = total |
|
|
|
self.disable = disable |
|
|
|
self.n = 0 |
|
|
|
|
|
|
|
def update(self, n): |
|
|
|
if self.disable: |
|
|
|
return |
|
|
|
|
|
|
|
self.n += n |
|
|
|
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) |
|
|
|
sys.stderr.flush() |
|
|
|
|
|
|
|
def __enter__(self): |
|
|
|
return self |
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
|
|
if self.disable: |
|
|
|
return |
|
|
|
|
|
|
|
sys.stderr.write('\n') |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context.pkl', model_dir='.') |
|
|
|
print(type(pipeline)) |