"""Gradio utilities. Note that the optional `progress` parameter can be both a `tqdm` module or a `gr.Progress` instance. """ import concurrent.futures import contextlib import glob import hashlib import logging import os import tempfile import time import urllib.request import jax import numpy as np from tensorflow.io import gfile @contextlib.contextmanager def timed(name): t0 = time.monotonic() timing = dict(dt=None) try: yield timing finally: timing['secs'] = time.monotonic() - t0 logging.info('Timed %s: %.1f secs', name, timing['secs']) def copy_file( src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False ): """Copies a file with progress bar. Args: src: Source file (readable by `tf.io.gfile`) or URL. dst: Destination file. Path must be readable by `tf.io.gfile`. progress: An object with a `.tqdm` attribute, or `None`. block_size: Size of individual blocks to be read/written. overwrite: If `True`, overwrite `dst` if it exists. """ if os.path.dirname(dst): os.makedirs(os.path.dirname(dst), exist_ok=True) if os.path.exists(dst) and not overwrite: return if src.startswith('http://') or src.startswith('https://'): opener = urllib.request.urlopen request = urllib.request.Request(src, method='HEAD') response = urllib.request.urlopen(request) content_length = response.headers.get('Content-Length') n = int(np.ceil(int(content_length) / block_size)) print('content_length', content_length) else: opener = lambda path: gfile.GFile(path, 'rb') stats = gfile.stat(src) n = int(np.ceil(stats.length / block_size)) if progress is None: range_or_trange = range else: range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download') with opener(src) as fin: with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout: for _ in range_or_trange(n): fout.write(fin.read(block_size)) gfile.rename(f'{dst}-PARTIAL', dst) _estimated_real = [(10, 10)] _memory_cache = {} def get_with_progress(getter, secs, progress, step=0.1): """Returns result from `getter` while showing a progress bar.""" with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(getter) for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'): if not future.done(): time.sleep(step) return future.result() def _get_array_sizes(tree): return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)] def get_memory_cache( key, getter, max_cache_size_bytes, progress=None, estimated_secs=None ): """Keeps cache below specified size by removing elements not last accessed.""" if key in _memory_cache: _memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order return _memory_cache[key] est, real = zip(*_estimated_real) if estimated_secs is None: estimated_secs = sum(est) / len(est) with timed(f'loading {key}') as timing: estimated_secs *= sum(real) / sum(est) _memory_cache[key] = get_with_progress(getter, estimated_secs, progress) _estimated_real.append((estimated_secs, timing['secs'])) sz = sum(_get_array_sizes(list(_memory_cache.values()))) logging.info('New memory cache size=%.1f MB', sz/1e6) while sz > max_cache_size_bytes: k, v = next(iter(_memory_cache.items())) if k == key: break s = sum(_get_array_sizes(v)) logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6) _memory_cache.pop(k) sz -= s return _memory_cache[key] def get_memory_cache_info(): """Returns number of items and total size in bytes.""" sizes = _get_array_sizes(_memory_cache) return len(_memory_cache), sum(sizes) CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache') def 

def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by evicting least recently used files."""
  fname = os.path.basename(path_or_url)
  # Hash the full path/URL so different sources sharing a basename don't
  # collide in the cache directory.
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(sz for _, sz, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz / 1e6)
  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)  # oldest access time first
    if path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, s / 1e6)
    os.unlink(path)
    sz -= s

  return dst


def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = [
      os.path.getsize(p) for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  ]
  return len(sizes), sum(sizes)
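
# Hedged usage sketch, not part of the original module: the URL and size limit
# are illustrative assumptions. `get_disk_cache` downloads into CACHE_DIR once
# and afterwards returns the local path directly; files with the oldest access
# times are evicted once the limit is exceeded.
def _example_disk_cache_usage():
  import tqdm  # stands in for `gr.Progress()` outside a Gradio event handler

  local_path = get_disk_cache(
      'https://example.com/checkpoints/model.npz',  # hypothetical URL
      max_cache_size_bytes=20 * 1024**3,
      progress=tqdm,
  )
  n_items, total_bytes = get_disk_cache_info()
  logging.info('Disk cache: %d items, %.1f MB', n_items, total_bytes / 1e6)
  return local_path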