mcding committed
Commit: 23c4bb7
Parent(s): 32afe8e

fix lfs issue

Files changed:
- app.py +137 -49
- attacked_image.png +0 -0
- kit/__init__.py +1 -1
- kit/metrics/__init__.py +0 -2
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth +3 -0
- kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth +3 -0
- kit/metrics/lpips/weights/v0.0/alex.pth +3 -0
- kit/metrics/lpips/weights/v0.0/squeeze.pth +3 -0
- kit/metrics/lpips/weights/v0.0/vgg.pth +3 -0
- kit/metrics/lpips/weights/v0.1/alex.pth +3 -0
- kit/metrics/lpips/weights/v0.1/squeeze.pth +3 -0
- kit/metrics/lpips/weights/v0.1/vgg.pth +3 -0
- kit/metrics/perceptual.py +1 -30
- kit/metrics/watson/__init__.py +0 -4
- kit/metrics/watson/color_wrapper.py +0 -103
- kit/metrics/watson/dct2d.py +0 -105
- kit/metrics/watson/deep_loss.py +0 -307
- kit/metrics/watson/loss_provider.py +0 -180
- kit/metrics/watson/rfft2d.py +0 -87
- kit/metrics/watson/shift_wrapper.py +0 -51
- kit/metrics/watson/ssim.py +0 -95
- kit/metrics/watson/watson.py +0 -123
- kit/metrics/watson/watson_fft.py +0 -139
- kit/metrics/watson/watson_vgg.py +0 -202
- kit/models/stable_signature.onnx +3 -0
- requirements.txt +1 -0
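The commit message refers to the Git LFS problem this change resolves: the model weights listed above are now checked in as LFS pointer files (three lines of text: a spec version, a sha256 oid, and a byte size) rather than raw binaries. As an illustration only (the helper name is not part of the repository), a pointer file can be told apart from a materialized weight like this:

import os

def is_lfs_pointer(path: str) -> bool:
    # An LFS pointer begins with the spec line seen in the +3-line additions
    # further down; a real .pth/.onnx binary will not.
    with open(path, "rb") as f:
        head = f.read(64)
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")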
app.py
CHANGED
@@ -1,11 +1,8 @@
 import os
 import gradio as gr
-
 import numpy as np
 import json
 import redis
-from PIL import Image
-import time
 import plotly.graph_objects as go
 from datetime import datetime
 from kit import compute_performance, compute_quality
@@ -13,6 +10,35 @@ import dotenv

 dotenv.load_dotenv()

+CSS = """
+.tabs button{
+    font-size: 24px;
+}
+#download_btn {
+    height: 91.6px;
+}
+#submit_btn {
+    height: 91.6px;
+}
+#original_image {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}
+#uploaded_image {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}
+#leaderboard_plot {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+    width: 512px; /* Adjust width as needed */
+    height: 512px; /* Adjust height as needed */
+}
+"""
+

 # Connect to Redis
 redis_client = redis.Redis(
@@ -39,7 +65,10 @@ def get_submissions_from_redis():
     return [json.loads(submission) for submission in submissions]


-def update_leaderboard(submissions):
+def update_plot(
+    submissions,
+    current_name=None,
+):
     names = [sub["name"] for sub in submissions]
     performances = [float(sub["performance"]) for sub in submissions]
     qualities = [float(sub["quality"]) for sub in submissions]
@@ -47,16 +76,26 @@ def update_leaderboard(submissions):
     # Create scatter plot
     fig = go.Figure()

-
-
-
-
-
-
-
-
+    for name, quality, performance in zip(names, qualities, performances):
+        if name == current_name:
+            marker = dict(symbol="star", size=15, color="blue")
+        elif name.startswith("Baseline: "):
+            marker = dict(symbol="square", size=10, color="grey")
+        else:
+            marker = dict(symbol="circle", size=10, color="green")
+
+        fig.add_trace(
+            go.Scatter(
+                x=[quality],
+                y=[performance],
+                mode="markers+text",
+                text=[name],
+                textposition="top center",
+                name=name,
+                marker=marker,
+                hovertemplate=f"{'Name: ' + name if not name.startswith('Baseline: ') else name}<br>(Performance, Quality) = ({performance:.3f}, {quality:.3f})",
+            )
         )
-    )

     # Add circles
     circle_radii = np.linspace(0, 1, 5)
@@ -76,13 +115,17 @@ def update_leaderboard(submissions):

     # Update layout
     fig.update_layout(
-
-
-
-
-
-
-
+        xaxis_title="Image Quality Degredation",
+        yaxis_title="Watermark Detection Performance",
+        xaxis=dict(
+            range=[0, 1.1], titlefont=dict(size=16)  # Adjust this value as needed
+        ),
+        yaxis=dict(
+            range=[0, 1.1], titlefont=dict(size=16)  # Adjust this value as needed
+        ),
+        width=512,
+        height=512,
+        showlegend=False,  # Remove legend
     )

     return fig
@@ -91,17 +134,15 @@ def update_leaderboard(submissions):
 def process_submission(name, image):
     original_image = Image.open("./image.png")
     progress = gr.Progress()
-    progress(0, desc="
-    time.sleep(0.5)
-    progress(0.1, desc="Decoding")
+    progress(0, desc="Detecting Watermark")
     performance = compute_performance(image)
-    progress(0.
+    progress(0.4, desc="Evaluating Image Quality")
     quality = compute_quality(image, original_image)
-    progress(0
+    progress(1.0, desc="Uploading Results")
     save_to_redis(name, performance, quality)

     submissions = get_submissions_from_redis()
-    leaderboard_plot =
+    leaderboard_plot = update_plot(submissions, current_name=name)

     # Calculate rank
     distances = [
@@ -111,13 +152,15 @@ def process_submission(name, image):
     rank = (
         sorted(distances, reverse=True).index(np.sqrt(quality**2 + performance**2)) + 1
     )
-
-
-
-
-
-
-
+    gr.Info(f"You ranked {rank} out of {len(submissions)}!")
+    return (
+        leaderboard_plot,
+        f"{rank} out of {len(submissions)}",
+        name,
+        f"{performance:.3f}",
+        f"{quality:.3f}",
+        f"{np.sqrt(quality**2 + performance**2):.3f}",
+    )


 def upload_and_evaluate(name, image):
@@ -129,61 +172,105 @@ def upload_and_evaluate(name, image):


 def create_interface():
-    with gr.Blocks() as demo:
+    with gr.Blocks(css=CSS) as demo:
         gr.Markdown(
             """
-            # Erasing the Invisible
+            # Erasing the Invisible Demo
+            TODO: Improve title and add description, add icon.jpg, also improve configs in README.md
             """
         )

-        with gr.Tabs() as tabs:
+        with gr.Tabs(elem_classes=["tabs"]) as tabs:
             with gr.Tab("Original Watermarked Image", id="download"):
+                gr.Markdown(
+                    """
+                    TODO: Add descriptions
+                    """
+                )
                 with gr.Column():
                     original_image = gr.Image(
                         value="./image.png",
+                        format="png",
                         label="Original Watermarked Image",
                         show_label=True,
                         height=512,
+                        width=512,
+                        type="filepath",
                         show_download_button=False,
                         show_share_button=False,
                         show_fullscreen_button=False,
+                        container=True,
+                        elem_id="original_image",
                     )
                 with gr.Row():
-                    gr.DownloadButton(
-                        "Download Watermarked Image",
+                    download_btn = gr.DownloadButton(
+                        "Download Watermarked Image",
+                        value="./image.png",
+                        elem_id="download_btn",
+                    )
+                    submit_btn = gr.Button(
+                        "Submit Your Removal", elem_id="submit_btn"
                     )
-                    submit_btn = gr.Button("Submit Your Removal", scale=3)

-            with gr.Tab(
+            with gr.Tab(
+                "Submit Watermark Removed Image",
+                id="submit",
+                elem_classes="gr-tab-header",
+            ):
+                gr.Markdown(
+                    """
+                    TODO: Add descriptions
+                    """
+                )
                 with gr.Column():
                     uploaded_image = gr.Image(
                         label="Your Watermark Removed Image",
+                        format="png",
                         show_label=True,
                         height=512,
+                        width=512,
+                        sources=["upload"],
                         type="pil",
                         show_download_button=False,
                         show_share_button=False,
                         show_fullscreen_button=False,
+                        container=True,
+                        placeholder="Upload your watermark removed image",
+                        elem_id="uploaded_image",
                     )
                 with gr.Row():
                     name_input = gr.Textbox(
-                        label="Your Name",
+                        label="Your Name", placeholder="Anonymous"
                     )
                     upload_btn = gr.Button("Upload and Evaluate")

-            with gr.Tab(
+            with gr.Tab(
+                "Evaluation Results and Your Ranking",
+                id="leaderboard",
+                elem_classes="gr-tab-header",
+            ):
+                gr.Markdown(
+                    """
+                    TODO: Add descriptions
+                    """
+                )
                 with gr.Column():
                     leaderboard_plot = gr.Plot(
-
+                        value=update_plot(get_submissions_from_redis()),
+                        show_label=False,
+                        elem_id="leaderboard_plot",
                     )
                 with gr.Row():
-                    rank_output = gr.
+                    rank_output = gr.Textbox(label="Your Ranking")
                     name_output = gr.Textbox(label="Your Name")
-                    performance_output = gr.
-                        label="Watermark Performance
+                    performance_output = gr.Textbox(
+                        label="Watermark Performance (lower is better)"
+                    )
+                    quality_output = gr.Textbox(
+                        label="Quality Degredation (lower is better)"
                     )
-
-                        label="
+                    overall_output = gr.Textbox(
+                        label="Overall Score (lower is better)"
                     )

         submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)
@@ -197,13 +284,14 @@ def create_interface():
                 name_output,
                 performance_output,
                 quality_output,
+                overall_output,
             ],
         )

         demo.load(
             lambda: [
                 gr.Image(value="./image.png", height=512, width=512),
+                gr.Plot(update_plot(get_submissions_from_redis())),
             ],
             outputs=[original_image, leaderboard_plot],
         )
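For reference, the overall score and rank that process_submission reports can be reproduced in a few lines. This is a minimal sketch of the same arithmetic visible in the diff (the Euclidean distance of (quality, performance) from the origin, ranked in descending order); the input values below are hypothetical, not the Redis-backed submissions.

import numpy as np

def overall_score(performance: float, quality: float) -> float:
    # Same formula as in process_submission: distance from the origin.
    return float(np.sqrt(quality**2 + performance**2))

def rank_of(submission_scores: list[float], my_score: float) -> int:
    # Rank 1 is the largest distance, mirroring
    # sorted(distances, reverse=True).index(...) + 1 in app.py.
    return sorted(submission_scores, reverse=True).index(my_score) + 1

# Hypothetical example
scores = [overall_score(p, q) for p, q in [(0.9, 0.1), (0.4, 0.2), (0.7, 0.3)]]
print(rank_of(scores, overall_score(0.4, 0.2)))  # -> 3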
attacked_image.png
DELETED
Binary file (155 kB)
kit/__init__.py
CHANGED
@@ -81,7 +81,7 @@ def compute_quality(attacked_image, clean_image, quiet=True):

     # Compress the image
     buffer = io.BytesIO()
-    attacked_image.save(buffer, format="JPEG", quality=
+    attacked_image.save(buffer, format="JPEG", quality=95)
     buffer.seek(0)

     # Update attacked_image with the compressed version
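The change above pins the JPEG re-encoding inside compute_quality to quality 95 (the old value is truncated in the rendered diff). A minimal standalone sketch of that normalization step, assuming a PIL image as input:

import io
from PIL import Image

def recompress(image: Image.Image, quality: int = 95) -> Image.Image:
    # Round-trip the attacked image through an in-memory JPEG at a fixed
    # quality before scoring, as compute_quality now does.
    image = image.convert("RGB")  # JPEG has no alpha channel
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    return Image.open(buffer)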
kit/metrics/__init__.py
CHANGED
@@ -13,9 +13,7 @@ from .image import (
 from .perceptual import (
     load_perceptual_models,
     compute_lpips,
-    compute_watson,
     compute_lpips_repeated,
-    compute_watson_repeated,
     compute_perceptual_metric_repeated,
 )
 from .aesthetics import (
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39a5d014670226d52c408e0dfec840b7626d80a73d003a6a144caafd5e02d031
+size 19423219

kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc48a8a2315cfdbc7bb8278be55f645e8a995e1a2fa234baec5eb41c4d33e070
+size 17850319

kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4a9481fdbce5ff02b252bcb25109b9f3b29841289fadf7e79e884d59f9357d5
+size 16801743

kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19b016304f54ae866e27f1eb498c0861f704958e7c37693adc5ce094e63904a8
+size 19423099

kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03603eee1864c2e5e97ef7079229609653db5b10594ca8b1de9e541d838cae9c
+size 17850199

kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb7fe561369ab6c7dad34b9316a56d2c6070582f0323656148e1107a242cd666
+size 16801623

kit/metrics/lpips/weights/v0.0/alex.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
+size 5455

kit/metrics/lpips/weights/v0.0/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
+size 10057

kit/metrics/lpips/weights/v0.0/vgg.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
+size 6735

kit/metrics/lpips/weights/v0.1/alex.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
+size 6009

kit/metrics/lpips/weights/v0.1/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
+size 10811

kit/metrics/lpips/weights/v0.1/vgg.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
+size 7289
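Each pointer above records the sha256 oid and byte size of the real weight file, so a local checkout can be validated against it. A hedged sketch (the helper name and any paths are illustrative, not part of the repository):

import hashlib
import os

def matches_pointer(path: str, oid: str, size: int) -> bool:
    # Compare the on-disk file against the oid/size recorded in its LFS pointer.
    if os.path.getsize(path) != size:
        return False
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == oid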
kit/metrics/perceptual.py
CHANGED
@@ -2,7 +2,6 @@ import torch
 from PIL import Image
 from torchvision import transforms
 from .lpips import LPIPS
-from .watson import LossProvider


 # Normalize image tensors
@@ -33,19 +32,10 @@ def to_tensor(images, norm_type="naive"):


 def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
-    assert metric_name in ["lpips"
+    assert metric_name in ["lpips"]
     if metric_name == "lpips":
         assert mode in ["vgg", "alex"]
         perceptual_model = LPIPS(net=mode).to(device)
-    elif metric_name == "watson":
-        assert mode in ["vgg", "dft", "fft"]
-        perceptual_model = (
-            LossProvider()
-            .get_loss_function(
-                "Watson-" + mode, colorspace="RGB", pretrained=True, reduction="none"
-            )
-            .to(device)
-        )
     else:
         assert False
     return perceptual_model
@@ -65,12 +55,6 @@ def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
     return compute_metric(image1, image2, perceptual_model, device)


-# Compute Watson distance between two images
-def compute_watson(image1, image2, mode="dft", device=torch.device("cuda")):
-    perceptual_model = load_perceptual_models("watson", mode, device)
-    return compute_metric(image1, image2, perceptual_model, device)
-
-
 # Compute metrics between pairs of images
 def compute_perceptual_metric_repeated(
     images1,
@@ -107,16 +91,3 @@ def compute_lpips_repeated(
     return compute_perceptual_metric_repeated(
         images1, images2, "lpips", mode, model, device
     )
-
-
-# Compute Watson distance between pairs of images
-def compute_watson_repeated(
-    images1,
-    images2,
-    mode="dft",
-    model=None,
-    device=torch.device("cuda"),
-):
-    return compute_perceptual_metric_repeated(
-        images1, images2, "watson", mode, model, device
-    )
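With the Watson metrics removed, LPIPS is the remaining perceptual metric in kit/metrics/perceptual.py. A usage sketch under assumptions: that the module accepts PIL images (as its to_tensor helper suggests), that a CUDA device is available as in the function's defaults, and with placeholder file names.

import torch
from PIL import Image
from kit.metrics.perceptual import compute_lpips

device = torch.device("cuda")  # matches the module's default device
original = Image.open("./image.png").convert("RGB")
attacked = Image.open("./my_attacked_image.png").convert("RGB")  # placeholder path

# LPIPS with the AlexNet backbone, presumably backed by the
# kit/metrics/lpips/weights/*.pth files added in this commit.
distance = compute_lpips(original, attacked, mode="alex", device=device)
print(distance)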
kit/metrics/watson/__init__.py
DELETED
@@ -1,4 +0,0 @@
-"""
-From https://github.com/facebookresearch/stable_signature
-"""
-from .loss_provider import LossProvider
kit/metrics/watson/color_wrapper.py
DELETED
@@ -1,103 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
|
6 |
-
class RGB2YCbCr(nn.Module):
|
7 |
-
def __init__(self):
|
8 |
-
super().__init__()
|
9 |
-
transf = torch.tensor(
|
10 |
-
[[0.299, 0.587, 0.114], [-0.1687, -0.3313, 0.5], [0.5, -0.4187, -0.0813]]
|
11 |
-
).transpose(0, 1)
|
12 |
-
self.transform = nn.Parameter(transf, requires_grad=False)
|
13 |
-
bias = torch.tensor([0, 0.5, 0.5])
|
14 |
-
self.bias = nn.Parameter(bias, requires_grad=False)
|
15 |
-
|
16 |
-
def forward(self, rgb):
|
17 |
-
N, C, H, W = rgb.shape
|
18 |
-
assert C == 3
|
19 |
-
rgb = rgb.transpose(1, 3)
|
20 |
-
cbcr = torch.matmul(rgb, self.transform)
|
21 |
-
cbcr += self.bias
|
22 |
-
return cbcr.transpose(1, 3)
|
23 |
-
|
24 |
-
|
25 |
-
class ColorWrapper(nn.Module):
|
26 |
-
"""
|
27 |
-
Extension for single-channel loss to work on color images
|
28 |
-
"""
|
29 |
-
|
30 |
-
def __init__(self, lossclass, args, kwargs, trainable=False):
|
31 |
-
"""
|
32 |
-
Parameters:
|
33 |
-
lossclass: class of the individual loss functions
|
34 |
-
trainable: bool, if True parameters of the loss are trained.
|
35 |
-
args: tuple, arguments for instantiation of loss fun
|
36 |
-
kwargs: dict, key word arguments for instantiation of loss fun
|
37 |
-
"""
|
38 |
-
super().__init__()
|
39 |
-
|
40 |
-
# submodules
|
41 |
-
self.add_module("to_YCbCr", RGB2YCbCr())
|
42 |
-
self.add_module("ly", lossclass(*args, **kwargs))
|
43 |
-
self.add_module("lcb", lossclass(*args, **kwargs))
|
44 |
-
self.add_module("lcr", lossclass(*args, **kwargs))
|
45 |
-
|
46 |
-
# weights
|
47 |
-
self.w_tild = nn.Parameter(torch.zeros(3), requires_grad=trainable)
|
48 |
-
|
49 |
-
@property
|
50 |
-
def w(self):
|
51 |
-
return F.softmax(self.w_tild, dim=0)
|
52 |
-
|
53 |
-
def forward(self, input, target):
|
54 |
-
# convert color space
|
55 |
-
input = self.to_YCbCr(input)
|
56 |
-
target = self.to_YCbCr(target)
|
57 |
-
|
58 |
-
ly = self.ly(input[:, [0], :, :], target[:, [0], :, :])
|
59 |
-
lcb = self.lcb(input[:, [1], :, :], target[:, [1], :, :])
|
60 |
-
lcr = self.lcr(input[:, [2], :, :], target[:, [2], :, :])
|
61 |
-
|
62 |
-
w = self.w
|
63 |
-
|
64 |
-
return ly * w[0] + lcb * w[1] + lcr * w[2]
|
65 |
-
|
66 |
-
|
67 |
-
class GreyscaleWrapper(nn.Module):
|
68 |
-
"""
|
69 |
-
Maps 3 channel RGB or 1 channel greyscale input to 3 greyscale channels
|
70 |
-
"""
|
71 |
-
|
72 |
-
def __init__(self, lossclass, args, kwargs):
|
73 |
-
"""
|
74 |
-
Parameters:
|
75 |
-
lossclass: class of the individual loss function
|
76 |
-
args: tuple, arguments for instantiation of loss fun
|
77 |
-
kwargs: dict, key word arguments for instantiation of loss fun
|
78 |
-
"""
|
79 |
-
super().__init__()
|
80 |
-
|
81 |
-
# submodules
|
82 |
-
self.add_module("loss", lossclass(*args, **kwargs))
|
83 |
-
|
84 |
-
def to_greyscale(self, tensor):
|
85 |
-
return (
|
86 |
-
tensor[:, [0], :, :] * 0.3
|
87 |
-
+ tensor[:, [1], :, :] * 0.59
|
88 |
-
+ tensor[:, [2], :, :] * 0.11
|
89 |
-
)
|
90 |
-
|
91 |
-
def forward(self, input, target):
|
92 |
-
(N, C, X, Y) = input.size()
|
93 |
-
|
94 |
-
if N == 3:
|
95 |
-
# convert input to greyscale
|
96 |
-
input = self.to_greyscale(input)
|
97 |
-
target = self.to_greyscale(target)
|
98 |
-
|
99 |
-
# input in now greyscale, expand to 3 channels
|
100 |
-
input = input.expand(N, 3, X, Y)
|
101 |
-
target = target.expand(N, 3, X, Y)
|
102 |
-
|
103 |
-
return self.loss.forward(input, target)
|
|
|
|
|
|
|
|
kit/metrics/watson/dct2d.py
DELETED
@@ -1,105 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
|
7 |
-
class Dct2d(nn.Module):
|
8 |
-
"""
|
9 |
-
Blockwhise 2D DCT
|
10 |
-
"""
|
11 |
-
|
12 |
-
def __init__(self, blocksize=8, interleaving=False):
|
13 |
-
"""
|
14 |
-
Parameters:
|
15 |
-
blocksize: int, size of the Blocks for discrete cosine transform
|
16 |
-
interleaving: bool, should the blocks interleave?
|
17 |
-
"""
|
18 |
-
super().__init__() # call super constructor
|
19 |
-
|
20 |
-
self.blocksize = blocksize
|
21 |
-
self.interleaving = interleaving
|
22 |
-
|
23 |
-
if interleaving:
|
24 |
-
self.stride = self.blocksize // 2
|
25 |
-
else:
|
26 |
-
self.stride = self.blocksize
|
27 |
-
|
28 |
-
# precompute DCT weight matrix
|
29 |
-
A = np.zeros((blocksize, blocksize))
|
30 |
-
for i in range(blocksize):
|
31 |
-
c_i = 1 / np.sqrt(2) if i == 0 else 1.0
|
32 |
-
for n in range(blocksize):
|
33 |
-
A[i, n] = (
|
34 |
-
np.sqrt(2 / blocksize)
|
35 |
-
* c_i
|
36 |
-
* np.cos((2 * n + 1) / (blocksize * 2) * i * np.pi)
|
37 |
-
)
|
38 |
-
|
39 |
-
# set up conv layer
|
40 |
-
self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32), requires_grad=False)
|
41 |
-
self.unfold = torch.nn.Unfold(
|
42 |
-
kernel_size=blocksize, padding=0, stride=self.stride
|
43 |
-
)
|
44 |
-
return
|
45 |
-
|
46 |
-
def forward(self, x):
|
47 |
-
"""
|
48 |
-
performs 2D blockwhise DCT
|
49 |
-
|
50 |
-
Parameters:
|
51 |
-
x: tensor of dimension (N, 1, h, w)
|
52 |
-
|
53 |
-
Return:
|
54 |
-
tensor of dimension (N, k, blocksize, blocksize)
|
55 |
-
where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients
|
56 |
-
"""
|
57 |
-
|
58 |
-
(N, C, H, W) = x.shape
|
59 |
-
assert C == 1, "DCT is only implemented for a single channel"
|
60 |
-
assert H >= self.blocksize, "Input too small for blocksize"
|
61 |
-
assert W >= self.blocksize, "Input too small for blocksize"
|
62 |
-
assert (H % self.stride == 0) and (
|
63 |
-
W % self.stride == 0
|
64 |
-
), "FFT is only for dimensions divisible by the blocksize"
|
65 |
-
|
66 |
-
# unfold to blocks
|
67 |
-
x = self.unfold(x)
|
68 |
-
# now shape (N, blocksize**2, k)
|
69 |
-
(N, _, k) = x.shape
|
70 |
-
x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2)
|
71 |
-
# now shape (N, #k, blocksize, blocksize)
|
72 |
-
# perform DCT
|
73 |
-
coeff = self.A.matmul(x).matmul(self.A.transpose(0, 1))
|
74 |
-
|
75 |
-
return coeff
|
76 |
-
|
77 |
-
def inverse(self, coeff, output_shape):
|
78 |
-
"""
|
79 |
-
performs 2D blockwhise iDCT
|
80 |
-
|
81 |
-
Parameters:
|
82 |
-
coeff: tensor of dimension (N, k, blocksize, blocksize)
|
83 |
-
where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients
|
84 |
-
output_shape: (h, w) dimensions of the reconstructed image
|
85 |
-
|
86 |
-
Return:
|
87 |
-
tensor of dimension (N, 1, h, w)
|
88 |
-
"""
|
89 |
-
if self.interleaving:
|
90 |
-
raise Exception(
|
91 |
-
"Inverse block DCT is not implemented for interleaving blocks!"
|
92 |
-
)
|
93 |
-
|
94 |
-
# perform iDCT
|
95 |
-
x = self.A.transpose(0, 1).matmul(coeff).matmul(self.A)
|
96 |
-
(N, k, _, _) = x.shape
|
97 |
-
x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k)
|
98 |
-
x = F.fold(
|
99 |
-
x,
|
100 |
-
output_size=(output_shape[-2], output_shape[-1]),
|
101 |
-
kernel_size=self.blocksize,
|
102 |
-
padding=0,
|
103 |
-
stride=self.blocksize,
|
104 |
-
)
|
105 |
-
return x
|
|
|
|
|
|
|
|
kit/metrics/watson/deep_loss.py
DELETED
@@ -1,307 +0,0 @@
|
|
1 |
-
# Deeploss function from Zhang et al. (2018)
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
from torchvision import models
|
5 |
-
from collections import namedtuple
|
6 |
-
|
7 |
-
|
8 |
-
class NetLinLayer(nn.Module):
|
9 |
-
"""A single linear layer which does a 1x1 conv"""
|
10 |
-
|
11 |
-
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
12 |
-
super(NetLinLayer, self).__init__()
|
13 |
-
|
14 |
-
layers = (
|
15 |
-
[
|
16 |
-
nn.Dropout(),
|
17 |
-
]
|
18 |
-
if (use_dropout)
|
19 |
-
else [
|
20 |
-
nn.Dropout(p=0.0),
|
21 |
-
]
|
22 |
-
)
|
23 |
-
layers += [
|
24 |
-
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
25 |
-
]
|
26 |
-
self.model = nn.Sequential(*layers)
|
27 |
-
|
28 |
-
|
29 |
-
def normalize_tensor(in_feat, eps=1e-10):
|
30 |
-
# norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1)).view(in_feat.size()[0],1,in_feat.size()[2],in_feat.size()[3]).repeat(1,in_feat.size()[1],1,1)
|
31 |
-
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view(
|
32 |
-
in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
|
33 |
-
)
|
34 |
-
return in_feat / (norm_factor.expand_as(in_feat) + eps)
|
35 |
-
|
36 |
-
|
37 |
-
class vgg16(torch.nn.Module):
|
38 |
-
def __init__(self, requires_grad=False, pretrained=True):
|
39 |
-
super(vgg16, self).__init__()
|
40 |
-
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
41 |
-
self.slice1 = torch.nn.Sequential()
|
42 |
-
self.slice2 = torch.nn.Sequential()
|
43 |
-
self.slice3 = torch.nn.Sequential()
|
44 |
-
self.slice4 = torch.nn.Sequential()
|
45 |
-
self.slice5 = torch.nn.Sequential()
|
46 |
-
self.N_slices = 5
|
47 |
-
for x in range(4):
|
48 |
-
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
49 |
-
for x in range(4, 9):
|
50 |
-
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
51 |
-
for x in range(9, 16):
|
52 |
-
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
53 |
-
for x in range(16, 23):
|
54 |
-
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
55 |
-
for x in range(23, 30):
|
56 |
-
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
57 |
-
if not requires_grad:
|
58 |
-
for param in self.parameters():
|
59 |
-
param.requires_grad = False
|
60 |
-
|
61 |
-
def forward(self, X):
|
62 |
-
h = self.slice1(X)
|
63 |
-
h_relu1_2 = h
|
64 |
-
h = self.slice2(h)
|
65 |
-
h_relu2_2 = h
|
66 |
-
h = self.slice3(h)
|
67 |
-
h_relu3_3 = h
|
68 |
-
h = self.slice4(h)
|
69 |
-
h_relu4_3 = h
|
70 |
-
h = self.slice5(h)
|
71 |
-
h_relu5_3 = h
|
72 |
-
vgg_outputs = namedtuple(
|
73 |
-
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
74 |
-
)
|
75 |
-
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
76 |
-
return out
|
77 |
-
|
78 |
-
|
79 |
-
class squeezenet(torch.nn.Module):
|
80 |
-
def __init__(self, requires_grad=False, pretrained=True):
|
81 |
-
super(squeezenet, self).__init__()
|
82 |
-
pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
|
83 |
-
self.slice1 = torch.nn.Sequential()
|
84 |
-
self.slice2 = torch.nn.Sequential()
|
85 |
-
self.slice3 = torch.nn.Sequential()
|
86 |
-
self.slice4 = torch.nn.Sequential()
|
87 |
-
self.slice5 = torch.nn.Sequential()
|
88 |
-
self.slice6 = torch.nn.Sequential()
|
89 |
-
self.slice7 = torch.nn.Sequential()
|
90 |
-
self.N_slices = 7
|
91 |
-
for x in range(2):
|
92 |
-
self.slice1.add_module(str(x), pretrained_features[x])
|
93 |
-
for x in range(2, 5):
|
94 |
-
self.slice2.add_module(str(x), pretrained_features[x])
|
95 |
-
for x in range(5, 8):
|
96 |
-
self.slice3.add_module(str(x), pretrained_features[x])
|
97 |
-
for x in range(8, 10):
|
98 |
-
self.slice4.add_module(str(x), pretrained_features[x])
|
99 |
-
for x in range(10, 11):
|
100 |
-
self.slice5.add_module(str(x), pretrained_features[x])
|
101 |
-
for x in range(11, 12):
|
102 |
-
self.slice6.add_module(str(x), pretrained_features[x])
|
103 |
-
for x in range(12, 13):
|
104 |
-
self.slice7.add_module(str(x), pretrained_features[x])
|
105 |
-
if not requires_grad:
|
106 |
-
for param in self.parameters():
|
107 |
-
param.requires_grad = False
|
108 |
-
|
109 |
-
def forward(self, X):
|
110 |
-
h = self.slice1(X)
|
111 |
-
h_relu1 = h
|
112 |
-
h = self.slice2(h)
|
113 |
-
h_relu2 = h
|
114 |
-
h = self.slice3(h)
|
115 |
-
h_relu3 = h
|
116 |
-
h = self.slice4(h)
|
117 |
-
h_relu4 = h
|
118 |
-
h = self.slice5(h)
|
119 |
-
h_relu5 = h
|
120 |
-
h = self.slice6(h)
|
121 |
-
h_relu6 = h
|
122 |
-
h = self.slice7(h)
|
123 |
-
h_relu7 = h
|
124 |
-
vgg_outputs = namedtuple(
|
125 |
-
"SqueezeOutputs",
|
126 |
-
["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
|
127 |
-
)
|
128 |
-
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
129 |
-
|
130 |
-
return out
|
131 |
-
|
132 |
-
|
133 |
-
class alexnet(torch.nn.Module):
|
134 |
-
def __init__(self, requires_grad=False, pretrained=True):
|
135 |
-
super(alexnet, self).__init__()
|
136 |
-
alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
|
137 |
-
self.slice1 = torch.nn.Sequential()
|
138 |
-
self.slice2 = torch.nn.Sequential()
|
139 |
-
self.slice3 = torch.nn.Sequential()
|
140 |
-
self.slice4 = torch.nn.Sequential()
|
141 |
-
self.slice5 = torch.nn.Sequential()
|
142 |
-
self.N_slices = 5
|
143 |
-
for x in range(2):
|
144 |
-
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
145 |
-
for x in range(2, 5):
|
146 |
-
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
147 |
-
for x in range(5, 8):
|
148 |
-
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
149 |
-
for x in range(8, 10):
|
150 |
-
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
151 |
-
for x in range(10, 12):
|
152 |
-
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
153 |
-
if not requires_grad:
|
154 |
-
for param in self.parameters():
|
155 |
-
param.requires_grad = False
|
156 |
-
|
157 |
-
def forward(self, X):
|
158 |
-
h = self.slice1(X)
|
159 |
-
h_relu1 = h
|
160 |
-
h = self.slice2(h)
|
161 |
-
h_relu2 = h
|
162 |
-
h = self.slice3(h)
|
163 |
-
h_relu3 = h
|
164 |
-
h = self.slice4(h)
|
165 |
-
h_relu4 = h
|
166 |
-
h = self.slice5(h)
|
167 |
-
h_relu5 = h
|
168 |
-
alexnet_outputs = namedtuple(
|
169 |
-
"AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
|
170 |
-
)
|
171 |
-
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
172 |
-
|
173 |
-
return out
|
174 |
-
|
175 |
-
|
176 |
-
class PNetLin(nn.Module):
|
177 |
-
def __init__(
|
178 |
-
self,
|
179 |
-
pnet_type="vgg",
|
180 |
-
pnet_rand=False,
|
181 |
-
pnet_tune=False,
|
182 |
-
use_dropout=True,
|
183 |
-
use_gpu=True,
|
184 |
-
spatial=False,
|
185 |
-
version="0.1",
|
186 |
-
colorspace="RGB",
|
187 |
-
reduction="none",
|
188 |
-
):
|
189 |
-
super(PNetLin, self).__init__()
|
190 |
-
|
191 |
-
self.use_gpu = use_gpu
|
192 |
-
self.pnet_type = pnet_type
|
193 |
-
self.pnet_tune = pnet_tune
|
194 |
-
self.pnet_rand = pnet_rand
|
195 |
-
self.spatial = spatial
|
196 |
-
self.version = version
|
197 |
-
self.colorspace = colorspace
|
198 |
-
self.reduction = reduction
|
199 |
-
|
200 |
-
if self.pnet_type in ["vgg", "vgg16"]:
|
201 |
-
net_type = vgg16
|
202 |
-
self.chns = [64, 128, 256, 512, 512]
|
203 |
-
elif self.pnet_type == "alex":
|
204 |
-
net_type = alexnet
|
205 |
-
self.chns = [64, 192, 384, 256, 256]
|
206 |
-
elif self.pnet_type == "squeeze":
|
207 |
-
net_type = squeezenet
|
208 |
-
self.chns = [64, 128, 256, 384, 384, 512, 512]
|
209 |
-
|
210 |
-
if self.pnet_tune:
|
211 |
-
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=True)
|
212 |
-
else:
|
213 |
-
self.net = [
|
214 |
-
net_type(pretrained=not self.pnet_rand, requires_grad=False),
|
215 |
-
]
|
216 |
-
|
217 |
-
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
218 |
-
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
219 |
-
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
220 |
-
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
221 |
-
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
222 |
-
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
223 |
-
if self.pnet_type == "squeeze": # 7 layers for squeezenet
|
224 |
-
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
225 |
-
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
226 |
-
self.lins += [self.lin5, self.lin6]
|
227 |
-
|
228 |
-
self.shift = torch.autograd.Variable(
|
229 |
-
torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
|
230 |
-
)
|
231 |
-
self.scale = torch.autograd.Variable(
|
232 |
-
torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
|
233 |
-
)
|
234 |
-
|
235 |
-
if use_gpu:
|
236 |
-
if self.pnet_tune:
|
237 |
-
self.net.cuda()
|
238 |
-
else:
|
239 |
-
self.net[0].cuda()
|
240 |
-
self.shift = self.shift.cuda()
|
241 |
-
self.scale = self.scale.cuda()
|
242 |
-
self.lin0.cuda()
|
243 |
-
self.lin1.cuda()
|
244 |
-
self.lin2.cuda()
|
245 |
-
self.lin3.cuda()
|
246 |
-
self.lin4.cuda()
|
247 |
-
if self.pnet_type == "squeeze":
|
248 |
-
self.lin5.cuda()
|
249 |
-
self.lin6.cuda()
|
250 |
-
|
251 |
-
def forward(self, in0, in1):
|
252 |
-
in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
|
253 |
-
in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
|
254 |
-
|
255 |
-
if self.colorspace == "Gray":
|
256 |
-
in0_sc = util.tensor2tensorGrayscaleLazy(in0_sc)
|
257 |
-
in1_sc = util.tensor2tensorGrayscaleLazy(in1_sc)
|
258 |
-
|
259 |
-
if self.version == "0.0":
|
260 |
-
# v0.0 - original release had a bug, where input was not scaled
|
261 |
-
in0_input = in0
|
262 |
-
in1_input = in1
|
263 |
-
else:
|
264 |
-
# v0.1
|
265 |
-
in0_input = in0_sc
|
266 |
-
in1_input = in1_sc
|
267 |
-
|
268 |
-
if self.pnet_tune:
|
269 |
-
outs0 = self.net.forward(in0_input)
|
270 |
-
outs1 = self.net.forward(in1_input)
|
271 |
-
else:
|
272 |
-
outs0 = self.net[0].forward(in0_input)
|
273 |
-
outs1 = self.net[0].forward(in1_input)
|
274 |
-
|
275 |
-
feats0 = {}
|
276 |
-
feats1 = {}
|
277 |
-
diffs = [0] * len(outs0)
|
278 |
-
|
279 |
-
for kk, out0 in enumerate(outs0):
|
280 |
-
feats0[kk] = normalize_tensor(outs0[kk]) # norm NN outputs
|
281 |
-
feats1[kk] = normalize_tensor(outs1[kk])
|
282 |
-
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 # squared diff
|
283 |
-
|
284 |
-
if self.spatial:
|
285 |
-
lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
286 |
-
if self.pnet_type == "squeeze":
|
287 |
-
lin_models.extend([self.lin5, self.lin6])
|
288 |
-
res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))]
|
289 |
-
return res
|
290 |
-
|
291 |
-
val = torch.mean(
|
292 |
-
torch.mean(self.lin0.model(diffs[0]), dim=3), dim=2
|
293 |
-
) # sum means over H, W
|
294 |
-
val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]), dim=3), dim=2)
|
295 |
-
val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]), dim=3), dim=2)
|
296 |
-
val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]), dim=3), dim=2)
|
297 |
-
val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]), dim=3), dim=2)
|
298 |
-
if self.pnet_type == "squeeze":
|
299 |
-
val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]), dim=3), dim=2)
|
300 |
-
val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]), dim=3), dim=2)
|
301 |
-
|
302 |
-
val = val.view(val.size()[0], val.size()[1], 1, 1)
|
303 |
-
|
304 |
-
if self.reduction == "sum":
|
305 |
-
val = torch.sum(val)
|
306 |
-
|
307 |
-
return val
|
|
|
|
|
|
|
kit/metrics/watson/loss_provider.py
DELETED
@@ -1,180 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import os
|
4 |
-
import warnings
|
5 |
-
from .color_wrapper import ColorWrapper, GreyscaleWrapper
|
6 |
-
from .shift_wrapper import ShiftWrapper
|
7 |
-
from .watson import WatsonDistance
|
8 |
-
from .watson_fft import WatsonDistanceFft
|
9 |
-
from .watson_vgg import WatsonDistanceVgg
|
10 |
-
from .deep_loss import PNetLin
|
11 |
-
from .ssim import SSIM
|
12 |
-
|
13 |
-
|
14 |
-
class LossProvider:
|
15 |
-
def __init__(self):
|
16 |
-
self.loss_functions = [
|
17 |
-
"L1",
|
18 |
-
"L2",
|
19 |
-
"SSIM",
|
20 |
-
"Watson-dct",
|
21 |
-
"Watson-fft",
|
22 |
-
"Watson-vgg",
|
23 |
-
"Deeploss-vgg",
|
24 |
-
"Deeploss-squeeze",
|
25 |
-
"Adaptive",
|
26 |
-
]
|
27 |
-
self.color_models = ["LA", "RGB"]
|
28 |
-
|
29 |
-
def load_state_dict(self, filename):
|
30 |
-
current_dir = os.path.dirname(__file__)
|
31 |
-
path = os.path.join(current_dir, "weights", filename)
|
32 |
-
return torch.load(path, map_location="cpu")
|
33 |
-
|
34 |
-
def get_loss_function(
|
35 |
-
self,
|
36 |
-
model,
|
37 |
-
colorspace="RGB",
|
38 |
-
reduction="sum",
|
39 |
-
deterministic=False,
|
40 |
-
pretrained=True,
|
41 |
-
image_size=None,
|
42 |
-
):
|
43 |
-
"""
|
44 |
-
returns a trained loss class.
|
45 |
-
model: one of the values returned by self.loss_functions
|
46 |
-
colorspace: 'LA' or 'RGB'
|
47 |
-
deterministic: bool, if false (default) uses shifting of image blocks for watson-fft
|
48 |
-
image_size: tuple, size of input images. Only required for adaptive loss. Eg: [3, 64, 64]
|
49 |
-
"""
|
50 |
-
warnings.filterwarnings("ignore")
|
51 |
-
is_greyscale = colorspace in ["grey", "Grey", "LA", "greyscale", "grey-scale"]
|
52 |
-
|
53 |
-
if model.lower() in ["l2"]:
|
54 |
-
loss = nn.MSELoss(reduction=reduction)
|
55 |
-
elif model.lower() in ["l1"]:
|
56 |
-
loss = nn.L1Loss(reduction=reduction)
|
57 |
-
elif model.lower() in ["ssim"]:
|
58 |
-
loss = SSIM(size_average=(reduction in ["sum", "mean"]))
|
59 |
-
elif model.lower() in ["watson", "watson-dct"]:
|
60 |
-
if is_greyscale:
|
61 |
-
if deterministic:
|
62 |
-
loss = WatsonDistance(reduction=reduction)
|
63 |
-
if pretrained:
|
64 |
-
loss.load_state_dict(
|
65 |
-
self.load_state_dict("gray_watson_dct_trial0.pth")
|
66 |
-
)
|
67 |
-
else:
|
68 |
-
loss = ShiftWrapper(WatsonDistance, (), {"reduction": reduction})
|
69 |
-
if pretrained:
|
70 |
-
loss.loss.load_state_dict(
|
71 |
-
self.load_state_dict("gray_watson_dct_trial0.pth")
|
72 |
-
)
|
73 |
-
else:
|
74 |
-
if deterministic:
|
75 |
-
loss = ColorWrapper(WatsonDistance, (), {"reduction": reduction})
|
76 |
-
if pretrained:
|
77 |
-
loss.load_state_dict(
|
78 |
-
self.load_state_dict("rgb_watson_dct_trial0.pth")
|
79 |
-
)
|
80 |
-
else:
|
81 |
-
loss = ShiftWrapper(
|
82 |
-
ColorWrapper, (WatsonDistance, (), {"reduction": reduction}), {}
|
83 |
-
)
|
84 |
-
if pretrained:
|
85 |
-
loss.loss.load_state_dict(
|
86 |
-
self.load_state_dict("rgb_watson_dct_trial0.pth")
|
87 |
-
)
|
88 |
-
elif model.lower() in ["watson-fft", "watson-dft"]:
|
89 |
-
if is_greyscale:
|
90 |
-
if deterministic:
|
91 |
-
loss = WatsonDistanceFft(reduction=reduction)
|
92 |
-
if pretrained:
|
93 |
-
loss.load_state_dict(
|
94 |
-
self.load_state_dict("gray_watson_fft_trial0.pth")
|
95 |
-
)
|
96 |
-
else:
|
97 |
-
loss = ShiftWrapper(WatsonDistanceFft, (), {"reduction": reduction})
|
98 |
-
if pretrained:
|
99 |
-
loss.loss.load_state_dict(
|
100 |
-
self.load_state_dict("gray_watson_fft_trial0.pth")
|
101 |
-
)
|
102 |
-
else:
|
103 |
-
if deterministic:
|
104 |
-
loss = ColorWrapper(WatsonDistanceFft, (), {"reduction": reduction})
|
105 |
-
if pretrained:
|
106 |
-
loss.load_state_dict(
|
107 |
-
self.load_state_dict("rgb_watson_fft_trial0.pth")
|
108 |
-
)
|
109 |
-
else:
|
110 |
-
loss = ShiftWrapper(
|
111 |
-
ColorWrapper,
|
112 |
-
(WatsonDistanceFft, (), {"reduction": reduction}),
|
113 |
-
{},
|
114 |
-
)
|
115 |
-
if pretrained:
|
116 |
-
loss.loss.load_state_dict(
|
117 |
-
self.load_state_dict("rgb_watson_fft_trial0.pth")
|
118 |
-
)
|
119 |
-
elif model.lower() in ["watson-vgg", "watson-deep"]:
|
120 |
-
if is_greyscale:
|
121 |
-
loss = GreyscaleWrapper(WatsonDistanceVgg, (), {"reduction": reduction})
|
122 |
-
if pretrained:
|
123 |
-
loss.loss.load_state_dict(
|
124 |
-
self.load_state_dict("gray_watson_vgg_trial0.pth")
|
125 |
-
)
|
126 |
-
else:
|
127 |
-
loss = WatsonDistanceVgg(reduction=reduction)
|
128 |
-
if pretrained:
|
129 |
-
loss.load_state_dict(
|
130 |
-
self.load_state_dict("rgb_watson_vgg_trial0.pth")
|
131 |
-
)
|
132 |
-
elif model.lower() in ["deeploss-vgg"]:
|
133 |
-
if is_greyscale:
|
134 |
-
loss = GreyscaleWrapper(
|
135 |
-
PNetLin,
|
136 |
-
(),
|
137 |
-
{"pnet_type": "vgg", "reduction": reduction, "use_dropout": False},
|
138 |
-
)
|
139 |
-
if pretrained:
|
140 |
-
loss.loss.load_state_dict(
|
141 |
-
self.load_state_dict("gray_pnet_lin_vgg_trial0.pth")
|
142 |
-
)
|
143 |
-
else:
|
144 |
-
loss = PNetLin(pnet_type="vgg", reduction=reduction, use_dropout=False)
|
145 |
-
if pretrained:
|
146 |
-
loss.load_state_dict(
|
147 |
-
self.load_state_dict("rgb_pnet_lin_vgg_trial0.pth")
|
148 |
-
)
|
149 |
-
elif model.lower() in ["deeploss-squeeze"]:
|
150 |
-
if is_greyscale:
|
151 |
-
loss = GreyscaleWrapper(
|
152 |
-
PNetLin,
|
153 |
-
(),
|
154 |
-
{
|
155 |
-
"pnet_type": "squeeze",
|
156 |
-
"reduction": reduction,
|
157 |
-
"use_dropout": False,
|
158 |
-
},
|
159 |
-
)
|
160 |
-
if pretrained:
|
161 |
-
loss.loss.load_state_dict(
|
162 |
-
self.load_state_dict("gray_pnet_lin_squeeze_trial0.pth")
|
163 |
-
)
|
164 |
-
else:
|
165 |
-
loss = PNetLin(
|
166 |
-
pnet_type="squeeze", reduction=reduction, use_dropout=False
|
167 |
-
)
|
168 |
-
if pretrained:
|
169 |
-
loss.load_state_dict(
|
170 |
-
self.load_state_dict("rgb_pnet_lin_squeeze_trial0.pth")
|
171 |
-
)
|
172 |
-
else:
|
173 |
-
raise Exception('Metric "{}" not implemented'.format(model))
|
174 |
-
|
175 |
-
# freeze all training of the loss functions
|
176 |
-
if pretrained:
|
177 |
-
for param in loss.parameters():
|
178 |
-
param.requires_grad = False
|
179 |
-
|
180 |
-
return loss
|
|
|
|
|
|
|
kit/metrics/watson/rfft2d.py
DELETED
@@ -1,87 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.fft as fft
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
|
7 |
-
class Rfft2d(nn.Module):
|
8 |
-
"""
|
9 |
-
Blockwhise 2D FFT
|
10 |
-
for fixed blocksize of 8x8
|
11 |
-
"""
|
12 |
-
|
13 |
-
def __init__(self, blocksize=8, interleaving=False):
|
14 |
-
"""
|
15 |
-
Parameters:
|
16 |
-
"""
|
17 |
-
super().__init__() # call super constructor
|
18 |
-
|
19 |
-
self.blocksize = blocksize
|
20 |
-
self.interleaving = interleaving
|
21 |
-
if interleaving:
|
22 |
-
self.stride = self.blocksize // 2
|
23 |
-
else:
|
24 |
-
self.stride = self.blocksize
|
25 |
-
|
26 |
-
self.unfold = torch.nn.Unfold(
|
27 |
-
kernel_size=self.blocksize, padding=0, stride=self.stride
|
28 |
-
)
|
29 |
-
return
|
30 |
-
|
31 |
-
def forward(self, x):
|
32 |
-
"""
|
33 |
-
performs 2D blockwhise DCT
|
34 |
-
|
35 |
-
Parameters:
|
36 |
-
x: tensor of dimension (N, 1, h, w)
|
37 |
-
|
38 |
-
Return:
|
39 |
-
tensor of dimension (N, k, b, b/2, 2)
|
40 |
-
where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block real FFT coefficients.
|
41 |
-
The last dimension is pytorches representation of complex values
|
42 |
-
"""
|
43 |
-
|
44 |
-
(N, C, H, W) = x.shape
|
45 |
-
assert C == 1, "FFT is only implemented for a single channel"
|
46 |
-
assert H >= self.blocksize, "Input too small for blocksize"
|
47 |
-
assert W >= self.blocksize, "Input too small for blocksize"
|
48 |
-
assert (H % self.stride == 0) and (
|
49 |
-
W % self.stride == 0
|
50 |
-
), "FFT is only for dimensions divisible by the blocksize"
|
51 |
-
|
52 |
-
# unfold to blocks
|
53 |
-
x = self.unfold(x)
|
54 |
-
# now shape (N, 64, k)
|
55 |
-
(N, _, k) = x.shape
|
56 |
-
x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2)
|
57 |
-
# now shape (N, #k, b, b)
|
58 |
-
# perform DCT
|
59 |
-
coeff = fft.rfft(x)
|
60 |
-
coeff = torch.view_as_real(coeff)
|
61 |
-
|
62 |
-
return coeff / self.blocksize**2
|
63 |
-
|
64 |
-
def inverse(self, coeff, output_shape):
|
65 |
-
"""
|
66 |
-
performs 2D blockwhise inverse rFFT
|
67 |
-
|
68 |
-
Parameters:
|
69 |
-
output_shape: Tuple, dimensions of the outpus sample
|
70 |
-
"""
|
71 |
-
if self.interleaving:
|
72 |
-
raise Exception(
|
73 |
-
"Inverse block FFT is not implemented for interleaving blocks!"
|
74 |
-
)
|
75 |
-
|
76 |
-
# perform iRFFT
|
77 |
-
x = fft.irfft(coeff, dim=2, signal_sizes=(self.blocksize, self.blocksize))
|
78 |
-
(N, k, _, _) = x.shape
|
79 |
-
x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k)
|
80 |
-
x = F.fold(
|
81 |
-
x,
|
82 |
-
output_size=(output_shape[-2], output_shape[-1]),
|
83 |
-
kernel_size=self.blocksize,
|
84 |
-
padding=0,
|
85 |
-
stride=self.blocksize,
|
86 |
-
)
|
87 |
-
return x * (self.blocksize**2)
|
|
|
|
|
|
|
kit/metrics/watson/shift_wrapper.py
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
import torch.nn as nn
|
2 |
-
import numpy as np
|
3 |
-
|
4 |
-
|
5 |
-
class ShiftWrapper(nn.Module):
|
6 |
-
"""
|
7 |
-
Extension for 2-dimensional inout loss functions.
|
8 |
-
Shifts the inputs by up to 4 pixels. Uses replication padding.
|
9 |
-
"""
|
10 |
-
|
11 |
-
def __init__(self, lossclass, args, kwargs):
|
12 |
-
"""
|
13 |
-
Parameters:
|
14 |
-
lossclass: class of the individual loss functions
|
15 |
-
trainable: bool, if True parameters of the loss are trained.
|
16 |
-
args: tuple, arguments for instantiation of loss fun
|
17 |
-
kwargs: dict, key word arguments for instantiation of loss fun
|
18 |
-
"""
|
19 |
-
super().__init__()
|
20 |
-
|
21 |
-
# submodules
|
22 |
-
self.add_module("loss", lossclass(*args, **kwargs))
|
23 |
-
|
24 |
-
# shift amount
|
25 |
-
self.max_shift = 8
|
26 |
-
|
27 |
-
# padding
|
28 |
-
self.pad = nn.ReplicationPad2d(self.max_shift // 2)
|
29 |
-
|
30 |
-
def forward(self, input, target):
|
31 |
-
# convert color space
|
32 |
-
input = self.pad(input)
|
33 |
-
target = self.pad(target)
|
34 |
-
|
35 |
-
shift_x = np.random.randint(self.max_shift)
|
36 |
-
shift_y = np.random.randint(self.max_shift)
|
37 |
-
|
38 |
-
input = input[
|
39 |
-
:,
|
40 |
-
:,
|
41 |
-
shift_x : -(self.max_shift - shift_x),
|
42 |
-
shift_y : -(self.max_shift - shift_y),
|
43 |
-
]
|
44 |
-
target = target[
|
45 |
-
:,
|
46 |
-
:,
|
47 |
-
shift_x : -(self.max_shift - shift_x),
|
48 |
-
shift_y : -(self.max_shift - shift_y),
|
49 |
-
]
|
50 |
-
|
51 |
-
return self.loss(input, target)
|
|
|
|
|
|
|
|
kit/metrics/watson/ssim.py
DELETED
@@ -1,95 +0,0 @@
-# SSIM implementation from https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
-import torch
-import torch.nn.functional as F
-from torch.autograd import Variable
-from math import exp
-
-
-def gaussian(window_size, sigma):
-    gauss = torch.Tensor(
-        [
-            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
-            for x in range(window_size)
-        ]
-    )
-    return gauss / gauss.sum()
-
-
-def create_window(window_size, channel):
-    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
-    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
-    window = Variable(
-        _2D_window.expand(channel, 1, window_size, window_size).contiguous()
-    )
-    return window
-
-
-def _ssim(img1, img2, window, window_size, channel, size_average=True):
-    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
-    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
-
-    mu1_sq = mu1.pow(2)
-    mu2_sq = mu2.pow(2)
-    mu1_mu2 = mu1 * mu2
-
-    sigma1_sq = (
-        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
-    )
-    sigma2_sq = (
-        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
-    )
-    sigma12 = (
-        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
-        - mu1_mu2
-    )
-
-    C1 = 0.01**2
-    C2 = 0.03**2
-
-    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
-        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
-    )
-
-    if size_average:
-        return ssim_map.mean()
-    else:
-        return ssim_map.mean(1).mean(1).mean(1)
-
-
-class SSIM(torch.nn.Module):
-    def __init__(self, window_size=11, size_average=True):
-        super(SSIM, self).__init__()
-        self.window_size = window_size
-        self.size_average = size_average
-        self.channel = 1
-        self.window = create_window(window_size, self.channel)
-
-    def forward(self, img1, img2):
-        (_, channel, _, _) = img1.size()
-
-        if channel == self.channel and self.window.data.type() == img1.data.type():
-            window = self.window
-        else:
-            window = create_window(self.window_size, channel)
-
-            if img1.is_cuda:
-                window = window.cuda(img1.get_device())
-            window = window.type_as(img1)
-
-            self.window = window
-            self.channel = channel
-
-        return 1 - _ssim(
-            img1, img2, window, self.window_size, channel, self.size_average
-        )
-
-
-def ssim(img1, img2, window_size=11, size_average=True):
-    (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel)
-
-    if img1.is_cuda:
-        window = window.cuda(img1.get_device())
-    window = window.type_as(img1)
-
-    return _ssim(img1, img2, window, window_size, channel, size_average)
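For reference, the removed SSIM helpers could be exercised as below. This is a hedged sketch that assumes an older checkout in which kit/metrics/watson/ssim.py still exists, since this commit deletes it.

# Quick check of the removed SSIM helpers (old-checkout assumption).
import torch
from kit.metrics.watson.ssim import SSIM, ssim

img1 = torch.rand(1, 3, 64, 64)
img2 = img1 + 0.05 * torch.randn_like(img1)

print(ssim(img1, img2))    # raw similarity, close to 1 for near-identical images
print(SSIM()(img1, img2))  # module form returns 1 - SSIM, i.e. a dissimilarity loss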
kit/metrics/watson/watson.py
DELETED
@@ -1,123 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from .dct2d import Dct2d
-
-EPS = 1e-10
-
-
-def softmax(a, b, factor=1):
-    concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
-    softmax_factors = F.softmax(concat * factor, dim=-1)
-    return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1]
-
-
-class WatsonDistance(nn.Module):
-    """
-    Loss function based on Watsons perceptual distance.
-    Based on DCT quantization
-    """
-
-    def __init__(self, blocksize=8, trainable=False, reduction="sum"):
-        """
-        Parameters:
-        blocksize: int, size of the Blocks for discrete cosine transform
-        trainable: bool, if True parameters of the loss are trained and dropout is enabled.
-        reduction: 'sum' or 'none', determines return format
-        """
-        super().__init__()
-
-        # input mapping
-        blocksize = torch.as_tensor(blocksize)
-
-        # module to perform 2D blockwise DCT
-        self.add_module("dct", Dct2d(blocksize=blocksize.item(), interleaving=False))
-
-        # parameters, initialized with values from watson paper
-        self.blocksize = nn.Parameter(blocksize, requires_grad=False)
-        if self.blocksize == 8:
-            # init with Jpeg QM
-            self.t_tild = nn.Parameter(
-                torch.log(
-                    torch.tensor(  # log-scaled weights
-                        [
-                            [1.40, 1.01, 1.16, 1.66, 2.40, 3.43, 4.79, 6.56],
-                            [1.01, 1.45, 1.32, 1.52, 2.00, 2.71, 3.67, 4.93],
-                            [1.16, 1.32, 2.24, 2.59, 2.98, 3.64, 4.60, 5.88],
-                            [1.66, 1.52, 2.59, 3.77, 4.55, 5.30, 6.28, 7.60],
-                            [2.40, 2.00, 2.98, 4.55, 6.15, 7.46, 8.71, 10.17],
-                            [3.43, 2.71, 3.64, 5.30, 7.46, 9.62, 11.58, 13.51],
-                            [4.79, 3.67, 4.60, 6.28, 8.71, 11.58, 14.50, 17.29],
-                            [6.56, 4.93, 5.88, 7.60, 10.17, 13.51, 17.29, 21.15],
-                        ]
-                    )
-                ),
-                requires_grad=trainable,
-            )
-        else:
-            # init with uniform QM
-            self.t_tild = nn.Parameter(
-                torch.zeros((self.blocksize, self.blocksize)), requires_grad=trainable
-            )
-
-        # other default parameters
-        self.alpha = nn.Parameter(
-            torch.tensor(0.649), requires_grad=trainable
-        )  # luminance masking
-        w = torch.tensor(0.7)  # contrast masking
-        self.w_tild = nn.Parameter(
-            torch.log(w / (1 - w)), requires_grad=trainable
-        )  # inverse of sigmoid
-        self.beta = nn.Parameter(torch.tensor(4.0), requires_grad=trainable)  # pooling
-
-        # dropout for training
-        self.dropout = nn.Dropout(0.5 if trainable else 0)
-
-        # reduction
-        self.reduction = reduction
-        if reduction not in ["sum", "none"]:
-            raise Exception(
-                'Reduction "{}" not supported. Valid values are: "sum", "none".'.format(
-                    reduction
-                )
-            )
-
-    @property
-    def t(self):
-        # returns QM
-        qm = torch.exp(self.t_tild)
-        return qm
-
-    @property
-    def w(self):
-        # return luminance masking parameter
-        return torch.sigmoid(self.w_tild)
-
-    def forward(self, input, target):
-        # dct
-        c0 = self.dct(target)
-        c1 = self.dct(input)
-
-        N, K, B, B = c0.shape
-
-        # luminance masking
-        avg_lum = torch.mean(c0[:, :, 0, 0])
-        t_l = self.t.view(1, 1, B, B).expand(N, K, B, B)
-        t_l = t_l * (((c0[:, :, 0, 0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(
-            N, K, 1, 1
-        )
-
-        # contrast masking
-        s = softmax(t_l, (c0.abs() + EPS) ** self.w * t_l ** (1 - self.w))
-
-        # pooling
-        watson_dist = (((c0 - c1) / s).abs() + EPS) ** self.beta
-        watson_dist = self.dropout(watson_dist) + EPS
-        watson_dist = torch.sum(watson_dist, dim=(1, 2, 3))
-        watson_dist = watson_dist ** (1 / self.beta)
-
-        # reduction
-        if self.reduction == "sum":
-            watson_dist = torch.sum(watson_dist)
-
-        return watson_dist
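A hedged usage sketch for the DCT-based Watson distance removed above. It assumes an older checkout that still contains kit/metrics/watson/ (including the dct2d module this class imports); a single-channel batch is assumed here, since RGB handling lived in the also-removed color_wrapper module.

# Usage sketch for the removed WatsonDistance (old-checkout assumption).
import torch
from kit.metrics.watson.watson import WatsonDistance

loss_fn = WatsonDistance(blocksize=8, reduction="sum")
target = torch.rand(2, 1, 128, 128)                   # H and W divisible by blocksize
distorted = target + 0.02 * torch.randn_like(target)
print(loss_fn(distorted, target))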
kit/metrics/watson/watson_fft.py
DELETED
@@ -1,139 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from .rfft2d import Rfft2d
-
-EPS = 1e-10
-
-
-def softmax(a, b, factor=1):
-    concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
-    softmax_factors = F.softmax(concat * factor, dim=-1)
-    return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1]
-
-
-class WatsonDistanceFft(nn.Module):
-    """
-    Loss function based on Watsons perceptual distance.
-    Based on FFT quantization
-    """
-
-    def __init__(self, blocksize=8, trainable=False, reduction="sum"):
-        """
-        Parameters:
-        blocksize: int, size of the Blocks for discrete cosine transform
-        trainable: bool, if True parameters of the loss are trained and dropout is enabled.
-        reduction: 'sum' or 'none', determines return format
-        """
-        super().__init__()
-        self.trainable = trainable
-
-        # input mapping
-        blocksize = torch.as_tensor(blocksize)
-
-        # module to perform 2D blockwise rFFT
-        self.add_module("fft", Rfft2d(blocksize=blocksize.item(), interleaving=False))
-
-        # parameters
-        self.weight_size = (blocksize, blocksize // 2 + 1)
-        self.blocksize = nn.Parameter(blocksize, requires_grad=False)
-        # init with uniform QM
-        self.t_tild = nn.Parameter(
-            torch.zeros(self.weight_size), requires_grad=trainable
-        )
-        self.alpha = nn.Parameter(
-            torch.tensor(0.1), requires_grad=trainable
-        )  # luminance masking
-        w = torch.tensor(0.2)  # contrast masking
-        self.w_tild = nn.Parameter(
-            torch.log(w / (1 - w)), requires_grad=trainable
-        )  # inverse of sigmoid
-        self.beta = nn.Parameter(torch.tensor(1.0), requires_grad=trainable)  # pooling
-
-        # phase weights
-        self.w_phase_tild = nn.Parameter(
-            torch.zeros(self.weight_size) - 2.0, requires_grad=trainable
-        )
-
-        # dropout for training
-        self.dropout = nn.Dropout(0.5 if trainable else 0)
-
-        # reduction
-        self.reduction = reduction
-        if reduction not in ["sum", "none"]:
-            raise Exception(
-                'Reduction "{}" not supported. Valid values are: "sum", "none".'.format(
-                    reduction
-                )
-            )
-
-    @property
-    def t(self):
-        # returns QM
-        qm = torch.exp(self.t_tild)
-        return qm
-
-    @property
-    def w(self):
-        # return luminance masking parameter
-        return torch.sigmoid(self.w_tild)
-
-    @property
-    def w_phase(self):
-        # return weights for phase
-        w_phase = torch.exp(self.w_phase_tild)
-        # set weights of non-phases to 0
-        if not self.trainable:
-            w_phase[0, 0] = 0.0
-            w_phase[0, self.weight_size[1] - 1] = 0.0
-            w_phase[self.weight_size[1] - 1, self.weight_size[1] - 1] = 0.0
-            w_phase[self.weight_size[1] - 1, 0] = 0.0
-        return w_phase
-
-    def forward(self, input, target):
-        # fft
-        c0 = self.fft(target)
-        c1 = self.fft(input)
-
-        N, K, H, W, _ = c0.shape
-
-        # get amplitudes
-        c0_amp = torch.norm(c0 + EPS, p="fro", dim=4)
-        c1_amp = torch.norm(c1 + EPS, p="fro", dim=4)
-
-        # luminance masking
-        avg_lum = torch.mean(c0_amp[:, :, 0, 0])
-        t_l = self.t.view(1, 1, H, W).expand(N, K, H, W)
-        t_l = t_l * (((c0_amp[:, :, 0, 0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(
-            N, K, 1, 1
-        )
-
-        # contrast masking
-        s = softmax(t_l, (c0_amp.abs() + EPS) ** self.w * t_l ** (1 - self.w))
-
-        # pooling
-        watson_dist = (((c0_amp - c1_amp) / s).abs() + EPS) ** self.beta
-        watson_dist = self.dropout(watson_dist) + EPS
-        watson_dist = torch.sum(watson_dist, dim=(1, 2, 3))
-        watson_dist = watson_dist ** (1 / self.beta)
-
-        # get phases
-        c0_phase = torch.atan2(c0[:, :, :, :, 1], c0[:, :, :, :, 0] + EPS)
-        c1_phase = torch.atan2(c1[:, :, :, :, 1], c1[:, :, :, :, 0] + EPS)
-
-        # angular distance
-        phase_dist = (
-            torch.acos(torch.cos(c0_phase - c1_phase) * (1 - EPS * 10**3))
-            * self.w_phase
-        )  # we multiply with a factor ->1 to prevent taking the gradient of acos(-1) or acos(1). The gradient in this case would be -/+ inf
-        phase_dist = self.dropout(phase_dist)
-        phase_dist = torch.sum(phase_dist, dim=(1, 2, 3))
-
-        # perceptual distance
-        distance = watson_dist + phase_dist
-
-        # reduce
-        if self.reduction == "sum":
-            distance = torch.sum(distance)
-
-        return distance
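The FFT variant mirrors the DCT version but compares block-FFT amplitudes and adds a weighted phase term. The same kind of hedged sketch applies, under the same old-checkout assumption; note the also-removed rfft2d module it imports may additionally require an older PyTorch, depending on which FFT API that module used.

# Usage sketch for the removed WatsonDistanceFft (old-checkout assumption).
import torch
from kit.metrics.watson.watson_fft import WatsonDistanceFft

loss_fn = WatsonDistanceFft(blocksize=8, reduction="sum")
target = torch.rand(2, 1, 128, 128)
distorted = target + 0.02 * torch.randn_like(target)
print(loss_fn(distorted, target))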
kit/metrics/watson/watson_vgg.py
DELETED
@@ -1,202 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
-
-EPS = 1e-10
-
-
-class VggFeatureExtractor(nn.Module):
-    def __init__(self):
-        super(VggFeatureExtractor, self).__init__()
-
-        # download vgg
-        vgg16 = torchvision.models.vgg16(pretrained=True).features
-
-        # set non trainable
-        for param in vgg16.parameters():
-            param.requires_grad = False
-
-        # slice model
-        self.slice1 = torch.nn.Sequential()
-        self.slice2 = torch.nn.Sequential()
-        self.slice3 = torch.nn.Sequential()
-        self.slice4 = torch.nn.Sequential()
-        self.slice5 = torch.nn.Sequential()
-
-        for x in range(4):  # conv relu conv relu
-            self.slice1.add_module(str(x), vgg16[x])
-        for x in range(4, 9):  # max conv relu conv relu
-            self.slice2.add_module(str(x), vgg16[x])
-        for x in range(9, 16):  # max cov relu conv relu conv relu
-            self.slice3.add_module(str(x), vgg16[x])
-        for x in range(16, 23):  # conv relu max conv relu conv relu
-            self.slice4.add_module(str(x), vgg16[x])
-        for x in range(23, 30):  # conv relu conv relu max conv relu
-            self.slice5.add_module(str(x), vgg16[x])
-
-    def forward(self, X):
-        h = self.slice1(X)
-        h_relu1_2 = h
-        h = self.slice2(h)
-        h_relu2_2 = h
-        h = self.slice3(h)
-        h_relu3_3 = h
-        h = self.slice4(h)
-        h_relu4_3 = h
-        h = self.slice5(h)
-        h_relu5_3 = h
-
-        return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
-
-
-def normalize_tensor(t):
-    # norms a tensor over the channel dimension to an euclidean length of 1.
-    N, C, H, W = t.shape
-    norm_factor = torch.sqrt(torch.sum(t**2, dim=1)).view(N, 1, H, W)
-    return t / (norm_factor.expand_as(t) + EPS)
-
-
-def softmax(a, b, factor=1):
-    concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
-    softmax_factors = F.softmax(concat * factor, dim=-1)
-    return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1]
-
-
-class WatsonDistanceVgg(nn.Module):
-    """
-    Loss function based on Watsons perceptual distance.
-    Based on deep feature extraction
-    """
-
-    def __init__(self, trainable=False, reduction="sum"):
-        """
-        Parameters:
-        trainable: bool, if True parameters of the loss are trained and dropout is enabled.
-        reduction: 'sum' or 'none', determines return format
-        """
-        super().__init__()
-
-        # module to perform feature extraction
-        self.add_module("vgg", VggFeatureExtractor())
-
-        # imagenet-normalization
-        self.shift = nn.Parameter(
-            torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1), requires_grad=False
-        )
-        self.scale = nn.Parameter(
-            torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1), requires_grad=False
-        )
-
-        # channel dimensions
-        self.L = 5
-        self.channels = [64, 128, 256, 512, 512]
-
-        # sensitivity parameters
-        self.t0_tild = nn.Parameter(
-            torch.zeros((self.channels[0])), requires_grad=trainable
-        )
-        self.t1_tild = nn.Parameter(
-            torch.zeros((self.channels[1])), requires_grad=trainable
-        )
-        self.t2_tild = nn.Parameter(
-            torch.zeros((self.channels[2])), requires_grad=trainable
-        )
-        self.t3_tild = nn.Parameter(
-            torch.zeros((self.channels[3])), requires_grad=trainable
-        )
-        self.t4_tild = nn.Parameter(
-            torch.zeros((self.channels[4])), requires_grad=trainable
-        )
-
-        # other default parameters
-        w = torch.tensor(0.2)  # contrast masking
-        self.w0_tild = nn.Parameter(
-            torch.log(w / (1 - w)), requires_grad=trainable
-        )  # inverse of sigmoid
-        self.w1_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.w2_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.w3_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.w4_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.beta = nn.Parameter(torch.tensor(1.0), requires_grad=trainable)  # pooling
-
-        # dropout for training
-        self.dropout = nn.Dropout(0.5 if trainable else 0)
-
-        # reduction
-        self.reduction = reduction
-        if reduction not in ["sum", "none"]:
-            raise Exception(
-                'Reduction "{}" not supported. Valid values are: "sum", "none".'.format(
-                    reduction
-                )
-            )
-
-    @property
-    def t(self):
-        return [
-            torch.exp(t)
-            for t in [
-                self.t0_tild,
-                self.t1_tild,
-                self.t2_tild,
-                self.t3_tild,
-                self.t4_tild,
-            ]
-        ]
-
-    @property
-    def w(self):
-        # return luminance masking parameter
-        return [
-            torch.sigmoid(w)
-            for w in [
-                self.w0_tild,
-                self.w1_tild,
-                self.w2_tild,
-                self.w3_tild,
-                self.w4_tild,
-            ]
-        ]
-
-    def forward(self, input, target):
-        # normalization
-        input = (input - self.shift.expand_as(input)) / self.scale.expand_as(input)
-        target = (target - self.shift.expand_as(target)) / self.scale.expand_as(target)
-
-        # feature extraction
-        c0 = self.vgg(target)
-        c1 = self.vgg(input)
-
-        # norm over channels
-        for l in range(self.L):
-            c0[l] = normalize_tensor(c0[l])
-            c1[l] = normalize_tensor(c1[l])
-
-        # contrast masking
-        t = self.t
-        w = self.w
-        s = []
-        for l in range(self.L):
-            N, C_l, H_l, W_l = c0[l].shape
-            t_l = t[l].view(1, C_l, 1, 1).expand(N, C_l, H_l, W_l)
-            s.append(softmax(t_l, (c0[l].abs() + EPS) ** w[l] * t_l ** (1 - w[l])))
-
-        # pooling
-        watson_dist = 0
-        for l in range(self.L):
-            _, _, H_l, W_l = c0[l].shape
-            layer_dist = (((c0[l] - c1[l]) / s[l]).abs() + EPS) ** self.beta
-            layer_dist = self.dropout(layer_dist) + EPS
-            layer_dist = torch.sum(
-                layer_dist, dim=(1, 2, 3)
-            )  # sum over dimensions of layer
-            layer_dist = (1 / (H_l * W_l)) * layer_dist  # normalize by layer size
-            watson_dist += layer_dist  # sum over layers
-        watson_dist = watson_dist ** (1 / self.beta)
-
-        # reduction
-        if self.reduction == "sum":
-            watson_dist = torch.sum(watson_dist)
-
-        return watson_dist
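The VGG variant needs no block transform; it compares channel-normalized VGG-16 feature maps with the same masking-and-pooling scheme. A hedged sketch under the same old-checkout assumption follows; torchvision downloads the pretrained VGG-16 weights on first use.

# Usage sketch for the removed WatsonDistanceVgg (old-checkout assumption).
import torch
from kit.metrics.watson.watson_vgg import WatsonDistanceVgg

loss_fn = WatsonDistanceVgg(reduction="sum")
target = torch.rand(1, 3, 224, 224)                   # 3-channel input
distorted = target + 0.02 * torch.randn_like(target)
print(loss_fn(distorted, target))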
kit/models/stable_signature.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b58841ab09f23e89acf5aedade09c7f65908ae33437c5242ad987d99b5cd2c1
+size 1228161
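The three added lines are a Git LFS pointer, not the ONNX graph itself; the roughly 1.2 MB binary is materialized by `git lfs pull`. A purely illustrative way to confirm the pointer resolves to a loadable model is sketched below; onnxruntime is an assumption here, since this diff does not show how the app actually loads stable_signature.onnx.

# Hypothetical sanity check of the LFS-tracked model (run after `git lfs pull`).
import onnxruntime as ort

sess = ort.InferenceSession("kit/models/stable_signature.onnx")
for inp in sess.get_inputs():
    print(inp.name, inp.shape)  # inspect the expected input tensors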
requirements.txt
CHANGED
@@ -5,6 +5,7 @@ torchvision
 transformers
 open_clip_torch
 numpy
+scipy
 Pillow
 redis
 plotly