| """ |
| This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py |
| """ |
| import os |
| import hashlib |
| import requests |
| from tqdm import tqdm |
|
|
|
|
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits. It may be a prefix of the
        full digest, in which case only that prefix is compared (an empty
        string therefore matches any file). Upper- and lower-case hex digits
        are treated as equal.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Hash in 1 MiB chunks so the whole file is never held in memory.
        for data in iter(lambda: f.read(1048576), b''):
            sha1.update(data)

    actual = sha1.hexdigest()
    # hexdigest() is always lower-case; normalise the expected value so an
    # upper-case hash does not produce a false mismatch.
    expected = sha1_hash.lower()
    # Compare only the overlapping prefix so shortened hashes are accepted.
    prefix_len = min(len(actual), len(expected))
    return actual[:prefix_len] == expected[:prefix_len]
|
|
|
|
def download_file(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If it names an existing
        directory, the file is stored inside it under the url's last path
        component.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file
        when hash is specified but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match.
    """
    url_basename = url.split('/')[-1]
    if path is None:
        fname = url_basename
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url_basename) if os.path.isdir(path) else path

    if overwrite or not os.path.exists(fname) or (
            sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(dirname, exist_ok=True)

        print('Downloading %s from %s...' % (fname, url))
        # Stream into a sibling temp file and move it into place only after
        # the transfer finished, so an interrupted download never leaves a
        # truncated file at `fname` that a later call would accept as complete.
        tmp_name = '%s.part.%d' % (fname, os.getpid())
        r = requests.get(url, stream=True)
        try:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            total_length = r.headers.get('content-length')
            chunks = r.iter_content(chunk_size=1024)
            if total_length is not None:
                # Size is known: wrap the stream in a progress bar counted in KB.
                total_length = int(total_length)
                chunks = tqdm(chunks,
                              total=int(total_length / 1024. + 0.5),
                              unit='KB',
                              unit_scale=False,
                              dynamic_ncols=True)
            with open(tmp_name, 'wb') as f:
                for chunk in chunks:
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
            # Same directory as the target, so this is a plain rename.
            os.replace(tmp_name, fname)
        except BaseException:
            # Drop the partial file on any failure (including Ctrl-C).
            if os.path.exists(tmp_name):
                os.remove(tmp_name)
            raise
        finally:
            # A streamed response holds its connection until explicitly closed.
            r.close()

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))

    return fname
|
|