|
- # Modified from https://github.com/pytorch/vision
- import os
- import os.path
- import hashlib
- import errno
- from tqdm import tqdm
- import numpy as np
- import torch
- import random
-
def mkdir(dir):
    """Create the single directory ``dir`` if it does not already exist.

    Like ``os.mkdir``, only the last path component is created; the parent
    must already exist.

    Args:
        dir (str): Path of the directory to create.

    Raises:
        OSError: if the parent is missing, if ``dir`` exists but is not a
            directory, or on any other ``os.mkdir`` failure.
    """
    try:
        os.mkdir(dir)
    except OSError as e:
        # Try-then-check instead of check-then-create: another process may
        # create the directory between an isdir() test and mkdir(). That is
        # fine, but an existing non-directory at this path is still an error.
        if e.errno != errno.EEXIST or not os.path.isdir(dir):
            raise
-
def colormap(N=256, normalized=False):
    """Build an ``(N, 3)`` RGB colormap, one row per integer label.

    Each label's bits are distributed round-robin over the R, G and B
    channels (bit 3*j goes to R, 3*j+1 to G, 3*j+2 to B) and written into
    each channel from the most significant bit downward.

    Args:
        N (int): Number of labels / rows to generate.
        normalized (bool): If True, return float32 values divided by 255;
            otherwise return uint8 values in [0, 255].

    Returns:
        numpy.ndarray of shape ``(N, 3)``.
    """
    out_dtype = 'float32' if normalized else 'uint8'
    table = np.zeros((N, 3), dtype=out_dtype)
    labels = np.arange(N)

    for channel in range(3):
        level = np.zeros_like(labels)
        for step in range(8):
            # Bit (3*step + channel) of the label lands at bit (7 - step).
            bit = (labels >> (3 * step + channel)) & 1
            level |= bit << (7 - step)
        table[:, channel] = level

    if normalized:
        return table / 255
    return table
-
# 256-entry uint8 label colormap, computed once at import time.
DEFAULT_COLORMAP = colormap(N=256, normalized=False)
-
def gen_bar_updater(pbar):
    """Build a ``urlretrieve`` reporthook that drives a tqdm-style progress bar.

    Args:
        pbar: Progress bar exposing ``total``, ``n`` and ``update(delta)``
            (e.g. a ``tqdm`` instance).

    Returns:
        A callable ``bar_update(count, block_size, total_size)`` suitable as
        the ``reporthook`` argument of ``urllib.request.urlretrieve``.
    """
    def bar_update(count, block_size, total_size):
        # urlretrieve reports total_size == -1 when the size is unknown
        # (e.g. no Content-Length); only adopt a positive size as the total.
        known_size = total_size > 0
        if pbar.total is None and known_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        if known_size:
            # The last block is usually partial, so count * block_size can
            # exceed the real size; clamp to avoid a bar that passes 100%.
            progress_bytes = min(progress_bytes, total_size)
        pbar.update(progress_bytes - pbar.n)

    return bar_update
-
-
def check_integrity(fpath, md5=None):
    """Check a file against an expected MD5 hex digest.

    Args:
        fpath (str): Path of the file to verify.
        md5 (str, optional): Expected digest. If None, no check is performed.

    Returns:
        bool: True when ``md5`` is None, or when ``fpath`` is an existing file
        whose ``hashlib.md5`` hexdigest equals ``md5`` exactly; otherwise False.
    """
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False

    digest = hashlib.md5()
    with open(fpath, 'rb') as handle:
        # Stream the file 1 MiB at a time so large files are not loaded whole.
        while True:
            block = handle.read(1024 * 1024)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest() == md5
-
-
def makedir_exist_ok(dirpath):
    """Python2-compatible ``os.makedirs(dirpath, exist_ok=True)``.

    Creates ``dirpath`` and any missing parents. An already existing
    directory is not an error.

    Args:
        dirpath (str): Directory path to create.

    Raises:
        OSError: on any creation failure, including when ``dirpath`` already
            exists but is not a directory (matching ``exist_ok=True``).
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        # Only "already exists as a directory" is tolerated; a regular file
        # at this path must not be silently accepted.
        if e.errno != errno.EEXIST or not os.path.isdir(dirpath):
            raise
-
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    If the target file already exists and passes ``check_integrity``, it is
    reused and nothing is downloaded.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in (``~`` is expanded;
            created if missing)
        filename (str): Name to save the file under. If None, use the basename of the URL
        md5 (str): MD5 checksum of the download. If None, do not check

    Raises:
        OSError: if the download fails (for an ``https:`` URL, only after the
            ``http:`` retry also fails).
        RuntimeError: if ``md5`` is given and the downloaded file does not
            match it.
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    def _retrieve(src):
        # Retrieve src into fpath; the context manager closes the bar even on error.
        with tqdm(unit='B', unit_scale=True) as pbar:
            urllib.request.urlretrieve(
                src, fpath,
                reporthook=gen_bar_updater(pbar)
            )

    try:
        print('Downloading ' + url + ' to ' + fpath)
        _retrieve(url)
    except OSError:
        # Only https URLs have a fallback; any other failure must propagate
        # instead of being silently swallowed.
        if not url.startswith('https:'):
            raise
        # NOTE(review): this downgrades to plain http; the md5 check below is
        # the only protection against a tampered download.
        url = url.replace('https:', 'http:', 1)
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        _retrieve(url)

    # Honour the documented md5 contract for freshly downloaded files too.
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
-
-
def list_dir(root, prefix=False):
    """Return the immediate subdirectories of ``root``.

    Args:
        root (str): Directory to scan; ``~`` is expanded.
        prefix (bool, optional): If exactly ``True``, each result is joined
            onto the expanded ``root``; otherwise only the bare directory
            names are returned.

    Returns:
        list of str, in ``os.listdir`` order.
    """
    base = os.path.expanduser(root)
    names = [entry for entry in os.listdir(base)
             if os.path.isdir(os.path.join(base, entry))]

    if prefix is not True:
        return names
    return [os.path.join(base, name) for name in names]
-
-
def list_files(root, suffix, prefix=False):
    """Return the regular files directly under ``root`` whose names end with ``suffix``.

    Args:
        root (str): Directory whose files need to be listed; ``~`` is expanded.
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or
            ('.jpg', '.png'). It is passed directly to ``str.endswith``.
        prefix (bool, optional): If exactly ``True``, each result is joined
            onto the expanded ``root``; otherwise only the bare file names
            are returned.

    Returns:
        list of str, in ``os.listdir`` order.
    """
    base = os.path.expanduser(root)
    matches = [entry for entry in os.listdir(base)
               if os.path.isfile(os.path.join(base, entry))
               and entry.endswith(suffix)]

    if prefix is not True:
        return matches
    return [os.path.join(base, name) for name in matches]
-
def set_seed(random_seed):
    """Seed the PyTorch (CPU and CUDA), NumPy and Python ``random`` RNGs.

    Args:
        random_seed (int): Seed passed unchanged to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)
|