Wazzzabeee's picture
Update utils.py
1348930
raw
history blame
2.19 kB
import numpy as np
import requests
import streamlit as st
from PIL import Image
from models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img, eccv16, siggraph17
# Define a function that we can use to load lottie files from a link.
@st.cache_data()
def load_lottieurl(url: str):
r = requests.get(url)
if r.status_code != 200:
return None
return r.json()
@st.cache_resource()
def change_model(current_model, model):
if current_model != model:
if model == "ECCV16":
loaded_model = eccv16(pretrained=True).eval()
elif model == "SIGGRAPH17":
loaded_model = siggraph17(pretrained=True).eval()
return loaded_model
else:
raise Exception("Model is the same as the current one.")
def format_time(seconds: float) -> str:
"""Formats time in seconds to a human readable format"""
if seconds < 60:
return f"{int(seconds)} seconds"
elif seconds < 3600:
minutes = seconds // 60
seconds %= 60
return f"{minutes} minutes and {int(seconds)} seconds"
elif seconds < 86400:
hours = seconds // 3600
minutes = (seconds % 3600) // 60
seconds %= 60
return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds"
else:
days = seconds // 86400
hours = (seconds % 86400) // 3600
minutes = (seconds % 3600) // 60
seconds %= 60
return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds"
# Function to colorize video frames
def colorize_frame(frame, colorizer) -> np.ndarray:
tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256))
return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())
def colorize_image(file, loaded_model):
img = load_img(file)
# If user input a colored image with 4 channels, discard the fourth channel
if img.shape[2] == 4:
img = img[:, :, :3]
tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
out_img = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu())
new_img = Image.fromarray((out_img * 255).astype(np.uint8))
return out_img, new_img