Clement Delteil commited on
Commit
935dd93
2 Parent(s): 70b0b9d e05bf5d

huge modif

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import streamlit as st
4
+ from models.deep_colorization.colorizers import *
5
+ import cv2
6
+ from PIL import Image
7
+ import tempfile
8
+ import moviepy.editor as mp
9
+ import time
10
+ from tqdm import tqdm
11
+
12
+
13
+ def format_time(seconds: float) -> str:
14
+ """Formats time in seconds to a human readable format"""
15
+ if seconds < 60:
16
+ return f"{int(seconds)} seconds"
17
+ elif seconds < 3600:
18
+ minutes = seconds // 60
19
+ seconds %= 60
20
+ return f"{minutes} minutes and {int(seconds)} seconds"
21
+ elif seconds < 86400:
22
+ hours = seconds // 3600
23
+ minutes = (seconds % 3600) // 60
24
+ seconds %= 60
25
+ return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds"
26
+ else:
27
+ days = seconds // 86400
28
+ hours = (seconds % 86400) // 3600
29
+ minutes = (seconds % 3600) // 60
30
+ seconds %= 60
31
+ return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds"
32
+
33
+
34
+ # Function to colorize video frames
35
+ def colorize_frame(frame, colorizer) -> np.ndarray:
36
+ tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256))
37
+ return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())
38
+
39
+ image = Image.open(r'img/streamlit.png') # Brand logo image (optional)
40
+
41
+ # Create two columns with different width
42
+ col1, col2 = st.columns([0.8, 0.2])
43
+ with col1: # To display the header text using css style
44
+ st.markdown(""" <style> .font {
45
+ font-size:35px ; font-family: 'Cooper Black'; color: #FF4B4B;}
46
+ </style> """, unsafe_allow_html=True)
47
+ st.markdown('<p class="font">Upload your photo or video here...</p>', unsafe_allow_html=True)
48
+
49
+ with col2: # To display brand logo
50
+ st.image(image, width=100)
51
+
52
+ # Add a header and expander in side bar
53
+ st.sidebar.markdown('<p class="font">Color Revive App</p>', unsafe_allow_html=True)
54
+ with st.sidebar.expander("About the App"):
55
+ st.write("""
56
+ Use this simple app to colorize your black and white images and videos with state of the art models.
57
+ """)
58
+
59
+ # Add file uploader to allow users to upload photos
60
+ uploaded_file = st.file_uploader("", type=['jpg', 'png', 'jpeg', 'mp4'])
61
+
62
+ # Add 'before' and 'after' columns
63
+ if uploaded_file is not None:
64
+ file_extension = os.path.splitext(uploaded_file.name)[1].lower()
65
+
66
+ if file_extension in ['.jpg', '.png', '.jpeg']:
67
+ image = Image.open(uploaded_file)
68
+
69
+ col1, col2 = st.columns([0.5, 0.5])
70
+ with col1:
71
+ st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
72
+ st.image(image, width=300)
73
+
74
+ # Add conditional statements to take the user input values
75
+ with col2:
76
+ st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
77
+ filter = st.sidebar.radio('Colorize your image with:',
78
+ ['Original', 'ECCV 16', 'SIGGRAPH 17'])
79
+ if filter == 'ECCV 16':
80
+ colorizer_eccv16 = eccv16(pretrained=True).eval()
81
+ img = load_img(uploaded_file)
82
+ tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
83
+ out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
84
+ st.image(out_img_eccv16, width=300)
85
+ elif filter == 'SIGGRAPH 17':
86
+ colorizer_siggraph17 = siggraph17(pretrained=True).eval()
87
+ img = load_img(uploaded_file)
88
+ tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
89
+ out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())
90
+ st.image(out_img_siggraph17, width=300)
91
+ else:
92
+ st.image(image, width=300)
93
+
94
+ elif file_extension == '.mp4': # If uploaded file is a video
95
+ # Save the video file to a temporary location
96
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
97
+ temp_file.write(uploaded_file.read())
98
+
99
+ # Open the video using cv2.VideoCapture
100
+ video = cv2.VideoCapture(temp_file.name)
101
+
102
+ # Get video information
103
+ fps = video.get(cv2.CAP_PROP_FPS)
104
+
105
+ # Create two columns for video display
106
+ col1, col2 = st.columns([0.5, 0.5])
107
+ with col1:
108
+ st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
109
+ st.video(temp_file.name)
110
+
111
+ with col2:
112
+ st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
113
+ filter = st.sidebar.radio('Colorize your video with:',
114
+ ['Original', 'ECCV 16', 'SIGGRAPH 17'])
115
+ if filter == 'ECCV 16':
116
+ colorizer = eccv16(pretrained=True).eval()
117
+ elif filter == 'SIGGRAPH 17':
118
+ colorizer = siggraph17(pretrained=True).eval()
119
+
120
+ if filter != 'Original':
121
+ with st.spinner("Colorizing frames..."):
122
+ # Colorize video frames and store in a list
123
+ output_frames = []
124
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
125
+ progress_bar = st.empty()
126
+
127
+ start_time = time.time()
128
+ for i in tqdm(range(total_frames), unit='frame', desc="Progress"):
129
+ ret, frame = video.read()
130
+ if not ret:
131
+ break
132
+
133
+ colorized_frame = colorize_frame(frame, colorizer)
134
+ output_frames.append((colorized_frame * 255).astype(np.uint8))
135
+
136
+ elapsed_time = time.time() - start_time
137
+ frames_completed = len(output_frames)
138
+ frames_remaining = total_frames - frames_completed
139
+ time_remaining = (frames_remaining / frames_completed) * elapsed_time
140
+
141
+ progress_bar.progress(frames_completed / total_frames)
142
+
143
+ if frames_completed < total_frames:
144
+ progress_bar.text(f"Time Remaining: {format_time(time_remaining)}")
145
+ else:
146
+ progress_bar.empty()
147
+
148
+ with st.spinner("Merging frames to video..."):
149
+ frame_size = output_frames[0].shape[:2]
150
+ output_filename = "output.mp4"
151
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 video
152
+ out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))
153
+
154
+ # Display the colorized video using st.video
155
+ for frame in output_frames:
156
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
157
+
158
+ out.write(frame_bgr)
159
+
160
+ out.release()
161
+
162
+ # Convert the output video to a format compatible with Streamlit
163
+ converted_filename = "converted_output.mp4"
164
+ clip = mp.VideoFileClip(output_filename)
165
+ clip.write_videofile(converted_filename, codec="libx264")
166
+
167
+ # Display the converted video using st.video()
168
+ st.video(converted_filename)
169
+
170
+ # Add a download button for the colorized video
171
+ st.download_button(
172
+ label="Download Colorized Video",
173
+ data=open(converted_filename, "rb").read(),
174
+ file_name="colorized_video.mp4"
175
+ )
176
+
177
+ # Close and delete the temporary file after processing
178
+ video.release()
179
+ temp_file.close()
180
+
181
+ # Add a feedback section in the sidebar
182
+ st.sidebar.title(' ') # Used to create some space between the filter widget and the comments section
183
+ st.sidebar.markdown(' ') # Used to create some space between the filter widget and the comments section
184
+ st.sidebar.subheader('Please help us improve!')
185
+ with st.sidebar.form(key='columns_in_form',
186
+ clear_on_submit=True): # set clear_on_submit=True so that the form will be reset/cleared once
187
+ # it's submitted
188
+ rating = st.slider("Please rate the app", min_value=1, max_value=5, value=3,
189
+ help='Drag the slider to rate the app. This is a 1-5 rating scale where 5 is the highest rating')
190
+ text = st.text_input(label='Please leave your feedback here')
191
+ submitted = st.form_submit_button('Submit')
192
+ if submitted:
193
+ st.write('Thanks for your feedback!')
194
+ st.markdown('Your Rating:')
195
+ st.markdown(rating)
196
+ st.markdown('Your Feedback:')
197
+ st.markdown(text)