mcding committed on
Commit 23c4bb7
1 Parent(s): 32afe8e

fix lfs issue

app.py CHANGED
@@ -1,11 +1,8 @@
 import os
 import gradio as gr
-
 import numpy as np
 import json
 import redis
-from PIL import Image
-import time
 import plotly.graph_objects as go
 from datetime import datetime
 from kit import compute_performance, compute_quality
@@ -13,6 +10,35 @@ import dotenv
 
 dotenv.load_dotenv()
 
+CSS = """
+.tabs button{
+    font-size: 24px;
+}
+#download_btn {
+    height: 91.6px;
+}
+#submit_btn {
+    height: 91.6px;
+}
+#original_image {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}
+#uploaded_image {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+}
+#leaderboard_plot {
+    display: block;
+    margin-left: auto;
+    margin-right: auto;
+    width: 512px; /* Adjust width as needed */
+    height: 512px; /* Adjust height as needed */
+}
+"""
+
 
 # Connect to Redis
 redis_client = redis.Redis(
@@ -39,7 +65,10 @@ def get_submissions_from_redis():
     return [json.loads(submission) for submission in submissions]
 
 
-def update_leaderboard(submissions):
+def update_plot(
+    submissions,
+    current_name=None,
+):
     names = [sub["name"] for sub in submissions]
     performances = [float(sub["performance"]) for sub in submissions]
     qualities = [float(sub["quality"]) for sub in submissions]
@@ -47,16 +76,26 @@ def update_leaderboard(submissions):
     # Create scatter plot
     fig = go.Figure()
 
-    fig.add_trace(
-        go.Scatter(
-            x=qualities,
-            y=performances,
-            mode="markers+text",
-            text=names,
-            textposition="top center",
-            name="Submissions",
+    for name, quality, performance in zip(names, qualities, performances):
+        if name == current_name:
+            marker = dict(symbol="star", size=15, color="blue")
+        elif name.startswith("Baseline: "):
+            marker = dict(symbol="square", size=10, color="grey")
+        else:
+            marker = dict(symbol="circle", size=10, color="green")
+
+        fig.add_trace(
+            go.Scatter(
+                x=[quality],
+                y=[performance],
+                mode="markers+text",
+                text=[name],
+                textposition="top center",
+                name=name,
+                marker=marker,
+                hovertemplate=f"{'Name: ' + name if not name.startswith('Baseline: ') else name}<br>(Performance, Quality) = ({performance:.3f}, {quality:.3f})",
+            )
         )
-    )
 
     # Add circles
     circle_radii = np.linspace(0, 1, 5)
@@ -76,13 +115,17 @@ def update_leaderboard(submissions):
 
     # Update layout
     fig.update_layout(
-        title="Submissions Leaderboard",
-        xaxis_title="Quality",
-        yaxis_title="Performance",
-        xaxis=dict(range=[0, 1]),
-        yaxis=dict(range=[0, 1]),
-        width=600,
-        height=600,
+        xaxis_title="Image Quality Degredation",
+        yaxis_title="Watermark Detection Performance",
+        xaxis=dict(
+            range=[0, 1.1], titlefont=dict(size=16)  # Adjust this value as needed
+        ),
+        yaxis=dict(
+            range=[0, 1.1], titlefont=dict(size=16)  # Adjust this value as needed
+        ),
+        width=512,
+        height=512,
+        showlegend=False,  # Remove legend
     )
 
     return fig
@@ -91,17 +134,15 @@ def update_leaderboard(submissions):
 def process_submission(name, image):
     original_image = Image.open("./image.png")
     progress = gr.Progress()
-    progress(0, desc="Processing")
-    time.sleep(0.5)
-    progress(0.1, desc="Decoding")
+    progress(0, desc="Detecting Watermark")
     performance = compute_performance(image)
-    progress(0.6, desc="Computing metric")
+    progress(0.4, desc="Evaluating Image Quality")
     quality = compute_quality(image, original_image)
-    progress(0.9, desc="Saving results")
+    progress(1.0, desc="Uploading Results")
     save_to_redis(name, performance, quality)
 
     submissions = get_submissions_from_redis()
-    leaderboard_plot = update_leaderboard(submissions)
+    leaderboard_plot = update_plot(submissions, current_name=name)
 
     # Calculate rank
     distances = [
@@ -111,13 +152,15 @@ def process_submission(name, image):
     rank = (
         sorted(distances, reverse=True).index(np.sqrt(quality**2 + performance**2)) + 1
     )
-
-    progress(1.0, desc="Complete")
-    return leaderboard_plot, rank, name, performance, quality
-
-
-def download_image():
-    return "./image.png"
+    gr.Info(f"You ranked {rank} out of {len(submissions)}!")
+    return (
+        leaderboard_plot,
+        f"{rank} out of {len(submissions)}",
+        name,
+        f"{performance:.3f}",
+        f"{quality:.3f}",
+        f"{np.sqrt(quality**2 + performance**2):.3f}",
+    )
 
 
 def upload_and_evaluate(name, image):
@@ -129,61 +172,105 @@ def upload_and_evaluate(name, image):
 
 
 def create_interface():
-    with gr.Blocks() as demo:
+    with gr.Blocks(css=CSS) as demo:
         gr.Markdown(
             """
-            # Erasing the Invisible -- NeurIPS24 Watermark Removal Challenge Demo
+            # Erasing the Invisible Demo
+            TODO: Improve title and add description, add icon.jpg, also improve configs in README.md
             """
         )
 
-        with gr.Tabs() as tabs:
+        with gr.Tabs(elem_classes=["tabs"]) as tabs:
             with gr.Tab("Original Watermarked Image", id="download"):
+                gr.Markdown(
+                    """
+                    TODO: Add descriptions
+                    """
+                )
                 with gr.Column():
                     original_image = gr.Image(
                         value="./image.png",
+                        format="png",
                         label="Original Watermarked Image",
                         show_label=True,
                         height=512,
+                        width=512,
+                        type="filepath",
                         show_download_button=False,
                         show_share_button=False,
                         show_fullscreen_button=False,
+                        container=True,
+                        elem_id="original_image",
                     )
                     with gr.Row():
-                        gr.DownloadButton(
-                            "Download Watermarked Image", value="./image.png", scale=3
+                        download_btn = gr.DownloadButton(
+                            "Download Watermarked Image",
+                            value="./image.png",
+                            elem_id="download_btn",
+                        )
+                        submit_btn = gr.Button(
+                            "Submit Your Removal", elem_id="submit_btn"
                         )
-                        submit_btn = gr.Button("Submit Your Removal", scale=3)
 
-            with gr.Tab("Submit Watermark Removed Image", id="submit"):
+            with gr.Tab(
+                "Submit Watermark Removed Image",
+                id="submit",
+                elem_classes="gr-tab-header",
+            ):
+                gr.Markdown(
+                    """
+                    TODO: Add descriptions
+                    """
+                )
                 with gr.Column():
                     uploaded_image = gr.Image(
                         label="Your Watermark Removed Image",
+                        format="png",
                         show_label=True,
                         height=512,
+                        width=512,
+                        sources=["upload"],
                         type="pil",
                         show_download_button=False,
                         show_share_button=False,
                         show_fullscreen_button=False,
+                        container=True,
+                        placeholder="Upload your watermark removed image",
+                        elem_id="uploaded_image",
                     )
                     with gr.Row():
                         name_input = gr.Textbox(
-                            label="Your Name",
+                            label="Your Name", placeholder="Anonymous"
                         )
                         upload_btn = gr.Button("Upload and Evaluate")
 
-            with gr.Tab("Evaluation Results and Your Ranking", id="leaderboard"):
+            with gr.Tab(
+                "Evaluation Results and Your Ranking",
+                id="leaderboard",
+                elem_classes="gr-tab-header",
+            ):
+                gr.Markdown(
+                    """
+                    TODO: Add descriptions
+                    """
+                )
                 with gr.Column():
                     leaderboard_plot = gr.Plot(
-                        label="Evalution Results", show_label=True
+                        value=update_plot(get_submissions_from_redis()),
+                        show_label=False,
+                        elem_id="leaderboard_plot",
                     )
                     with gr.Row():
-                        rank_output = gr.Number(label="Your Ranking")
+                        rank_output = gr.Textbox(label="Your Ranking")
                         name_output = gr.Textbox(label="Your Name")
-                        performance_output = gr.Number(
-                            label="Watermark Performance Score (lower is better)"
+                        performance_output = gr.Textbox(
+                            label="Watermark Performance (lower is better)"
+                        )
+                        quality_output = gr.Textbox(
+                            label="Quality Degredation (lower is better)"
                         )
-                        quality_output = gr.Number(
-                            label="Quality Degredation Score (lower is better)"
+                        overall_output = gr.Textbox(
+                            label="Overall Score (lower is better)"
                         )
 
         submit_btn.click(lambda: gr.Tabs(selected="submit"), None, tabs)
@@ -197,13 +284,14 @@ def create_interface():
                 name_output,
                 performance_output,
                 quality_output,
+                overall_output,
             ],
         )
 
         demo.load(
             lambda: [
                 gr.Image(value="./image.png", height=512, width=512),
-                gr.Plot(update_leaderboard(get_submissions_from_redis())),
+                gr.Plot(update_plot(get_submissions_from_redis())),
            ],
            outputs=[original_image, leaderboard_plot],
        )
attacked_image.png DELETED
Binary file (155 kB)
 
kit/__init__.py CHANGED
@@ -81,7 +81,7 @@ def compute_quality(attacked_image, clean_image, quiet=True):
 
     # Compress the image
     buffer = io.BytesIO()
-    attacked_image.save(buffer, format="JPEG", quality=90)
+    attacked_image.save(buffer, format="JPEG", quality=95)
     buffer.seek(0)
 
     # Update attacked_image with the compressed version
kit/metrics/__init__.py CHANGED
@@ -13,9 +13,7 @@ from .image import (
 from .perceptual import (
     load_perceptual_models,
     compute_lpips,
-    compute_watson,
     compute_lpips_repeated,
-    compute_watson_repeated,
     compute_perceptual_metric_repeated,
 )
 from .aesthetics import (
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39a5d014670226d52c408e0dfec840b7626d80a73d003a6a144caafd5e02d031
+size 19423219
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc48a8a2315cfdbc7bb8278be55f645e8a995e1a2fa234baec5eb41c4d33e070
+size 17850319
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4a9481fdbce5ff02b252bcb25109b9f3b29841289fadf7e79e884d59f9357d5
+size 16801743
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19b016304f54ae866e27f1eb498c0861f704958e7c37693adc5ce094e63904a8
+size 19423099
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03603eee1864c2e5e97ef7079229609653db5b10594ca8b1de9e541d838cae9c
+size 17850199
kit/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb7fe561369ab6c7dad34b9316a56d2c6070582f0323656148e1107a242cd666
+size 16801623
kit/metrics/lpips/weights/v0.0/alex.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
+size 5455
kit/metrics/lpips/weights/v0.0/squeeze.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
+size 10057
kit/metrics/lpips/weights/v0.0/vgg.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
+size 6735
kit/metrics/lpips/weights/v0.1/alex.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
+size 6009
kit/metrics/lpips/weights/v0.1/squeeze.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
+size 10811
kit/metrics/lpips/weights/v0.1/vgg.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
+size 7289
kit/metrics/perceptual.py CHANGED
@@ -2,7 +2,6 @@ import torch
 from PIL import Image
 from torchvision import transforms
 from .lpips import LPIPS
-from .watson import LossProvider
 
 
 # Normalize image tensors
@@ -33,19 +32,10 @@ def to_tensor(images, norm_type="naive"):
 
 
 def load_perceptual_models(metric_name, mode, device=torch.device("cuda")):
-    assert metric_name in ["lpips", "watson"]
+    assert metric_name in ["lpips"]
     if metric_name == "lpips":
         assert mode in ["vgg", "alex"]
         perceptual_model = LPIPS(net=mode).to(device)
-    elif metric_name == "watson":
-        assert mode in ["vgg", "dft", "fft"]
-        perceptual_model = (
-            LossProvider()
-            .get_loss_function(
-                "Watson-" + mode, colorspace="RGB", pretrained=True, reduction="none"
-            )
-            .to(device)
-        )
     else:
         assert False
     return perceptual_model
@@ -65,12 +55,6 @@ def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")):
     return compute_metric(image1, image2, perceptual_model, device)
 
 
-# Compute Watson distance between two images
-def compute_watson(image1, image2, mode="dft", device=torch.device("cuda")):
-    perceptual_model = load_perceptual_models("watson", mode, device)
-    return compute_metric(image1, image2, perceptual_model, device)
-
-
 # Compute metrics between pairs of images
 def compute_perceptual_metric_repeated(
     images1,
@@ -107,16 +91,3 @@ def compute_lpips_repeated(
     return compute_perceptual_metric_repeated(
         images1, images2, "lpips", mode, model, device
     )
-
-
-# Compute Watson distance between pairs of images
-def compute_watson_repeated(
-    images1,
-    images2,
-    mode="dft",
-    model=None,
-    device=torch.device("cuda"),
-):
-    return compute_perceptual_metric_repeated(
-        images1, images2, "watson", mode, model, device
-    )
kit/metrics/watson/__init__.py DELETED
@@ -1,4 +0,0 @@
-"""
-From https://github.com/facebookresearch/stable_signature
-"""
-from .loss_provider import LossProvider
kit/metrics/watson/color_wrapper.py DELETED
@@ -1,103 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class RGB2YCbCr(nn.Module):
7
- def __init__(self):
8
- super().__init__()
9
- transf = torch.tensor(
10
- [[0.299, 0.587, 0.114], [-0.1687, -0.3313, 0.5], [0.5, -0.4187, -0.0813]]
11
- ).transpose(0, 1)
12
- self.transform = nn.Parameter(transf, requires_grad=False)
13
- bias = torch.tensor([0, 0.5, 0.5])
14
- self.bias = nn.Parameter(bias, requires_grad=False)
15
-
16
- def forward(self, rgb):
17
- N, C, H, W = rgb.shape
18
- assert C == 3
19
- rgb = rgb.transpose(1, 3)
20
- cbcr = torch.matmul(rgb, self.transform)
21
- cbcr += self.bias
22
- return cbcr.transpose(1, 3)
23
-
24
-
25
- class ColorWrapper(nn.Module):
26
- """
27
- Extension for single-channel loss to work on color images
28
- """
29
-
30
- def __init__(self, lossclass, args, kwargs, trainable=False):
31
- """
32
- Parameters:
33
- lossclass: class of the individual loss functions
34
- trainable: bool, if True parameters of the loss are trained.
35
- args: tuple, arguments for instantiation of loss fun
36
- kwargs: dict, key word arguments for instantiation of loss fun
37
- """
38
- super().__init__()
39
-
40
- # submodules
41
- self.add_module("to_YCbCr", RGB2YCbCr())
42
- self.add_module("ly", lossclass(*args, **kwargs))
43
- self.add_module("lcb", lossclass(*args, **kwargs))
44
- self.add_module("lcr", lossclass(*args, **kwargs))
45
-
46
- # weights
47
- self.w_tild = nn.Parameter(torch.zeros(3), requires_grad=trainable)
48
-
49
- @property
50
- def w(self):
51
- return F.softmax(self.w_tild, dim=0)
52
-
53
- def forward(self, input, target):
54
- # convert color space
55
- input = self.to_YCbCr(input)
56
- target = self.to_YCbCr(target)
57
-
58
- ly = self.ly(input[:, [0], :, :], target[:, [0], :, :])
59
- lcb = self.lcb(input[:, [1], :, :], target[:, [1], :, :])
60
- lcr = self.lcr(input[:, [2], :, :], target[:, [2], :, :])
61
-
62
- w = self.w
63
-
64
- return ly * w[0] + lcb * w[1] + lcr * w[2]
65
-
66
-
67
- class GreyscaleWrapper(nn.Module):
68
- """
69
- Maps 3 channel RGB or 1 channel greyscale input to 3 greyscale channels
70
- """
71
-
72
- def __init__(self, lossclass, args, kwargs):
73
- """
74
- Parameters:
75
- lossclass: class of the individual loss function
76
- args: tuple, arguments for instantiation of loss fun
77
- kwargs: dict, key word arguments for instantiation of loss fun
78
- """
79
- super().__init__()
80
-
81
- # submodules
82
- self.add_module("loss", lossclass(*args, **kwargs))
83
-
84
- def to_greyscale(self, tensor):
85
- return (
86
- tensor[:, [0], :, :] * 0.3
87
- + tensor[:, [1], :, :] * 0.59
88
- + tensor[:, [2], :, :] * 0.11
89
- )
90
-
91
- def forward(self, input, target):
92
- (N, C, X, Y) = input.size()
93
-
94
- if N == 3:
95
- # convert input to greyscale
96
- input = self.to_greyscale(input)
97
- target = self.to_greyscale(target)
98
-
99
- # input in now greyscale, expand to 3 channels
100
- input = input.expand(N, 3, X, Y)
101
- target = target.expand(N, 3, X, Y)
102
-
103
- return self.loss.forward(input, target)
kit/metrics/watson/dct2d.py DELETED
@@ -1,105 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
-
7
- class Dct2d(nn.Module):
8
- """
9
- Blockwhise 2D DCT
10
- """
11
-
12
- def __init__(self, blocksize=8, interleaving=False):
13
- """
14
- Parameters:
15
- blocksize: int, size of the Blocks for discrete cosine transform
16
- interleaving: bool, should the blocks interleave?
17
- """
18
- super().__init__() # call super constructor
19
-
20
- self.blocksize = blocksize
21
- self.interleaving = interleaving
22
-
23
- if interleaving:
24
- self.stride = self.blocksize // 2
25
- else:
26
- self.stride = self.blocksize
27
-
28
- # precompute DCT weight matrix
29
- A = np.zeros((blocksize, blocksize))
30
- for i in range(blocksize):
31
- c_i = 1 / np.sqrt(2) if i == 0 else 1.0
32
- for n in range(blocksize):
33
- A[i, n] = (
34
- np.sqrt(2 / blocksize)
35
- * c_i
36
- * np.cos((2 * n + 1) / (blocksize * 2) * i * np.pi)
37
- )
38
-
39
- # set up conv layer
40
- self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32), requires_grad=False)
41
- self.unfold = torch.nn.Unfold(
42
- kernel_size=blocksize, padding=0, stride=self.stride
43
- )
44
- return
45
-
46
- def forward(self, x):
47
- """
48
- performs 2D blockwhise DCT
49
-
50
- Parameters:
51
- x: tensor of dimension (N, 1, h, w)
52
-
53
- Return:
54
- tensor of dimension (N, k, blocksize, blocksize)
55
- where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients
56
- """
57
-
58
- (N, C, H, W) = x.shape
59
- assert C == 1, "DCT is only implemented for a single channel"
60
- assert H >= self.blocksize, "Input too small for blocksize"
61
- assert W >= self.blocksize, "Input too small for blocksize"
62
- assert (H % self.stride == 0) and (
63
- W % self.stride == 0
64
- ), "FFT is only for dimensions divisible by the blocksize"
65
-
66
- # unfold to blocks
67
- x = self.unfold(x)
68
- # now shape (N, blocksize**2, k)
69
- (N, _, k) = x.shape
70
- x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2)
71
- # now shape (N, #k, blocksize, blocksize)
72
- # perform DCT
73
- coeff = self.A.matmul(x).matmul(self.A.transpose(0, 1))
74
-
75
- return coeff
76
-
77
- def inverse(self, coeff, output_shape):
78
- """
79
- performs 2D blockwhise iDCT
80
-
81
- Parameters:
82
- coeff: tensor of dimension (N, k, blocksize, blocksize)
83
- where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients
84
- output_shape: (h, w) dimensions of the reconstructed image
85
-
86
- Return:
87
- tensor of dimension (N, 1, h, w)
88
- """
89
- if self.interleaving:
90
- raise Exception(
91
- "Inverse block DCT is not implemented for interleaving blocks!"
92
- )
93
-
94
- # perform iDCT
95
- x = self.A.transpose(0, 1).matmul(coeff).matmul(self.A)
96
- (N, k, _, _) = x.shape
97
- x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k)
98
- x = F.fold(
99
- x,
100
- output_size=(output_shape[-2], output_shape[-1]),
101
- kernel_size=self.blocksize,
102
- padding=0,
103
- stride=self.blocksize,
104
- )
105
- return x
kit/metrics/watson/deep_loss.py DELETED
@@ -1,307 +0,0 @@
1
- # Deeploss function from Zhang et al. (2018)
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import models
5
- from collections import namedtuple
6
-
7
-
8
- class NetLinLayer(nn.Module):
9
- """A single linear layer which does a 1x1 conv"""
10
-
11
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
12
- super(NetLinLayer, self).__init__()
13
-
14
- layers = (
15
- [
16
- nn.Dropout(),
17
- ]
18
- if (use_dropout)
19
- else [
20
- nn.Dropout(p=0.0),
21
- ]
22
- )
23
- layers += [
24
- nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
25
- ]
26
- self.model = nn.Sequential(*layers)
27
-
28
-
29
- def normalize_tensor(in_feat, eps=1e-10):
30
- # norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1)).view(in_feat.size()[0],1,in_feat.size()[2],in_feat.size()[3]).repeat(1,in_feat.size()[1],1,1)
31
- norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view(
32
- in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
33
- )
34
- return in_feat / (norm_factor.expand_as(in_feat) + eps)
35
-
36
-
37
- class vgg16(torch.nn.Module):
38
- def __init__(self, requires_grad=False, pretrained=True):
39
- super(vgg16, self).__init__()
40
- vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
41
- self.slice1 = torch.nn.Sequential()
42
- self.slice2 = torch.nn.Sequential()
43
- self.slice3 = torch.nn.Sequential()
44
- self.slice4 = torch.nn.Sequential()
45
- self.slice5 = torch.nn.Sequential()
46
- self.N_slices = 5
47
- for x in range(4):
48
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
49
- for x in range(4, 9):
50
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
51
- for x in range(9, 16):
52
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
53
- for x in range(16, 23):
54
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
55
- for x in range(23, 30):
56
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
57
- if not requires_grad:
58
- for param in self.parameters():
59
- param.requires_grad = False
60
-
61
- def forward(self, X):
62
- h = self.slice1(X)
63
- h_relu1_2 = h
64
- h = self.slice2(h)
65
- h_relu2_2 = h
66
- h = self.slice3(h)
67
- h_relu3_3 = h
68
- h = self.slice4(h)
69
- h_relu4_3 = h
70
- h = self.slice5(h)
71
- h_relu5_3 = h
72
- vgg_outputs = namedtuple(
73
- "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
74
- )
75
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
76
- return out
77
-
78
-
79
- class squeezenet(torch.nn.Module):
80
- def __init__(self, requires_grad=False, pretrained=True):
81
- super(squeezenet, self).__init__()
82
- pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
83
- self.slice1 = torch.nn.Sequential()
84
- self.slice2 = torch.nn.Sequential()
85
- self.slice3 = torch.nn.Sequential()
86
- self.slice4 = torch.nn.Sequential()
87
- self.slice5 = torch.nn.Sequential()
88
- self.slice6 = torch.nn.Sequential()
89
- self.slice7 = torch.nn.Sequential()
90
- self.N_slices = 7
91
- for x in range(2):
92
- self.slice1.add_module(str(x), pretrained_features[x])
93
- for x in range(2, 5):
94
- self.slice2.add_module(str(x), pretrained_features[x])
95
- for x in range(5, 8):
96
- self.slice3.add_module(str(x), pretrained_features[x])
97
- for x in range(8, 10):
98
- self.slice4.add_module(str(x), pretrained_features[x])
99
- for x in range(10, 11):
100
- self.slice5.add_module(str(x), pretrained_features[x])
101
- for x in range(11, 12):
102
- self.slice6.add_module(str(x), pretrained_features[x])
103
- for x in range(12, 13):
104
- self.slice7.add_module(str(x), pretrained_features[x])
105
- if not requires_grad:
106
- for param in self.parameters():
107
- param.requires_grad = False
108
-
109
- def forward(self, X):
110
- h = self.slice1(X)
111
- h_relu1 = h
112
- h = self.slice2(h)
113
- h_relu2 = h
114
- h = self.slice3(h)
115
- h_relu3 = h
116
- h = self.slice4(h)
117
- h_relu4 = h
118
- h = self.slice5(h)
119
- h_relu5 = h
120
- h = self.slice6(h)
121
- h_relu6 = h
122
- h = self.slice7(h)
123
- h_relu7 = h
124
- vgg_outputs = namedtuple(
125
- "SqueezeOutputs",
126
- ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
127
- )
128
- out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
129
-
130
- return out
131
-
132
-
133
- class alexnet(torch.nn.Module):
134
- def __init__(self, requires_grad=False, pretrained=True):
135
- super(alexnet, self).__init__()
136
- alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
137
- self.slice1 = torch.nn.Sequential()
138
- self.slice2 = torch.nn.Sequential()
139
- self.slice3 = torch.nn.Sequential()
140
- self.slice4 = torch.nn.Sequential()
141
- self.slice5 = torch.nn.Sequential()
142
- self.N_slices = 5
143
- for x in range(2):
144
- self.slice1.add_module(str(x), alexnet_pretrained_features[x])
145
- for x in range(2, 5):
146
- self.slice2.add_module(str(x), alexnet_pretrained_features[x])
147
- for x in range(5, 8):
148
- self.slice3.add_module(str(x), alexnet_pretrained_features[x])
149
- for x in range(8, 10):
150
- self.slice4.add_module(str(x), alexnet_pretrained_features[x])
151
- for x in range(10, 12):
152
- self.slice5.add_module(str(x), alexnet_pretrained_features[x])
153
- if not requires_grad:
154
- for param in self.parameters():
155
- param.requires_grad = False
156
-
157
- def forward(self, X):
158
- h = self.slice1(X)
159
- h_relu1 = h
160
- h = self.slice2(h)
161
- h_relu2 = h
162
- h = self.slice3(h)
163
- h_relu3 = h
164
- h = self.slice4(h)
165
- h_relu4 = h
166
- h = self.slice5(h)
167
- h_relu5 = h
168
- alexnet_outputs = namedtuple(
169
- "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
170
- )
171
- out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
172
-
173
- return out
174
-
175
-
176
- class PNetLin(nn.Module):
177
- def __init__(
178
- self,
179
- pnet_type="vgg",
180
- pnet_rand=False,
181
- pnet_tune=False,
182
- use_dropout=True,
183
- use_gpu=True,
184
- spatial=False,
185
- version="0.1",
186
- colorspace="RGB",
187
- reduction="none",
188
- ):
189
- super(PNetLin, self).__init__()
190
-
191
- self.use_gpu = use_gpu
192
- self.pnet_type = pnet_type
193
- self.pnet_tune = pnet_tune
194
- self.pnet_rand = pnet_rand
195
- self.spatial = spatial
196
- self.version = version
197
- self.colorspace = colorspace
198
- self.reduction = reduction
199
-
200
- if self.pnet_type in ["vgg", "vgg16"]:
201
- net_type = vgg16
202
- self.chns = [64, 128, 256, 512, 512]
203
- elif self.pnet_type == "alex":
204
- net_type = alexnet
205
- self.chns = [64, 192, 384, 256, 256]
206
- elif self.pnet_type == "squeeze":
207
- net_type = squeezenet
208
- self.chns = [64, 128, 256, 384, 384, 512, 512]
209
-
210
- if self.pnet_tune:
211
- self.net = net_type(pretrained=not self.pnet_rand, requires_grad=True)
212
- else:
213
- self.net = [
214
- net_type(pretrained=not self.pnet_rand, requires_grad=False),
215
- ]
216
-
217
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
218
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
219
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
220
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
221
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
222
- self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
223
- if self.pnet_type == "squeeze": # 7 layers for squeezenet
224
- self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
225
- self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
226
- self.lins += [self.lin5, self.lin6]
227
-
228
- self.shift = torch.autograd.Variable(
229
- torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
230
- )
231
- self.scale = torch.autograd.Variable(
232
- torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
233
- )
234
-
235
- if use_gpu:
236
- if self.pnet_tune:
237
- self.net.cuda()
238
- else:
239
- self.net[0].cuda()
240
- self.shift = self.shift.cuda()
241
- self.scale = self.scale.cuda()
242
- self.lin0.cuda()
243
- self.lin1.cuda()
244
- self.lin2.cuda()
245
- self.lin3.cuda()
246
- self.lin4.cuda()
247
- if self.pnet_type == "squeeze":
248
- self.lin5.cuda()
249
- self.lin6.cuda()
250
-
251
- def forward(self, in0, in1):
252
- in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
253
- in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
254
-
255
- if self.colorspace == "Gray":
256
- in0_sc = util.tensor2tensorGrayscaleLazy(in0_sc)
257
- in1_sc = util.tensor2tensorGrayscaleLazy(in1_sc)
258
-
259
- if self.version == "0.0":
260
- # v0.0 - original release had a bug, where input was not scaled
261
- in0_input = in0
262
- in1_input = in1
263
- else:
264
- # v0.1
265
- in0_input = in0_sc
266
- in1_input = in1_sc
267
-
268
- if self.pnet_tune:
269
- outs0 = self.net.forward(in0_input)
270
- outs1 = self.net.forward(in1_input)
271
- else:
272
- outs0 = self.net[0].forward(in0_input)
273
- outs1 = self.net[0].forward(in1_input)
274
-
275
- feats0 = {}
276
- feats1 = {}
277
- diffs = [0] * len(outs0)
278
-
279
- for kk, out0 in enumerate(outs0):
280
- feats0[kk] = normalize_tensor(outs0[kk]) # norm NN outputs
281
- feats1[kk] = normalize_tensor(outs1[kk])
282
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 # squared diff
283
-
284
- if self.spatial:
285
- lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
286
- if self.pnet_type == "squeeze":
287
- lin_models.extend([self.lin5, self.lin6])
288
- res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))]
289
- return res
290
-
291
- val = torch.mean(
292
- torch.mean(self.lin0.model(diffs[0]), dim=3), dim=2
293
- ) # sum means over H, W
294
- val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]), dim=3), dim=2)
295
- val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]), dim=3), dim=2)
296
- val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]), dim=3), dim=2)
297
- val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]), dim=3), dim=2)
298
- if self.pnet_type == "squeeze":
299
- val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]), dim=3), dim=2)
300
- val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]), dim=3), dim=2)
301
-
302
- val = val.view(val.size()[0], val.size()[1], 1, 1)
303
-
304
- if self.reduction == "sum":
305
- val = torch.sum(val)
306
-
307
- return val
kit/metrics/watson/loss_provider.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import os
4
- import warnings
5
- from .color_wrapper import ColorWrapper, GreyscaleWrapper
6
- from .shift_wrapper import ShiftWrapper
7
- from .watson import WatsonDistance
8
- from .watson_fft import WatsonDistanceFft
9
- from .watson_vgg import WatsonDistanceVgg
10
- from .deep_loss import PNetLin
11
- from .ssim import SSIM
12
-
13
-
14
- class LossProvider:
15
- def __init__(self):
16
- self.loss_functions = [
17
- "L1",
18
- "L2",
19
- "SSIM",
20
- "Watson-dct",
21
- "Watson-fft",
22
- "Watson-vgg",
23
- "Deeploss-vgg",
24
- "Deeploss-squeeze",
25
- "Adaptive",
26
- ]
27
- self.color_models = ["LA", "RGB"]
28
-
29
- def load_state_dict(self, filename):
30
- current_dir = os.path.dirname(__file__)
31
- path = os.path.join(current_dir, "weights", filename)
32
- return torch.load(path, map_location="cpu")
33
-
34
- def get_loss_function(
35
- self,
36
- model,
37
- colorspace="RGB",
38
- reduction="sum",
39
- deterministic=False,
40
- pretrained=True,
41
- image_size=None,
42
- ):
43
- """
44
- returns a trained loss class.
45
- model: one of the values returned by self.loss_functions
46
- colorspace: 'LA' or 'RGB'
47
- deterministic: bool, if false (default) uses shifting of image blocks for watson-fft
48
- image_size: tuple, size of input images. Only required for adaptive loss. Eg: [3, 64, 64]
49
- """
50
- warnings.filterwarnings("ignore")
51
- is_greyscale = colorspace in ["grey", "Grey", "LA", "greyscale", "grey-scale"]
52
-
53
- if model.lower() in ["l2"]:
54
- loss = nn.MSELoss(reduction=reduction)
55
- elif model.lower() in ["l1"]:
56
- loss = nn.L1Loss(reduction=reduction)
57
- elif model.lower() in ["ssim"]:
58
- loss = SSIM(size_average=(reduction in ["sum", "mean"]))
59
- elif model.lower() in ["watson", "watson-dct"]:
60
- if is_greyscale:
61
- if deterministic:
62
- loss = WatsonDistance(reduction=reduction)
63
- if pretrained:
64
- loss.load_state_dict(
65
- self.load_state_dict("gray_watson_dct_trial0.pth")
66
- )
67
- else:
68
- loss = ShiftWrapper(WatsonDistance, (), {"reduction": reduction})
69
- if pretrained:
70
- loss.loss.load_state_dict(
71
- self.load_state_dict("gray_watson_dct_trial0.pth")
72
- )
73
- else:
74
- if deterministic:
75
- loss = ColorWrapper(WatsonDistance, (), {"reduction": reduction})
76
- if pretrained:
77
- loss.load_state_dict(
78
- self.load_state_dict("rgb_watson_dct_trial0.pth")
79
- )
80
- else:
81
- loss = ShiftWrapper(
82
- ColorWrapper, (WatsonDistance, (), {"reduction": reduction}), {}
83
- )
84
- if pretrained:
85
- loss.loss.load_state_dict(
86
- self.load_state_dict("rgb_watson_dct_trial0.pth")
87
- )
88
- elif model.lower() in ["watson-fft", "watson-dft"]:
89
- if is_greyscale:
90
- if deterministic:
91
- loss = WatsonDistanceFft(reduction=reduction)
92
- if pretrained:
93
- loss.load_state_dict(
94
- self.load_state_dict("gray_watson_fft_trial0.pth")
95
- )
96
- else:
97
- loss = ShiftWrapper(WatsonDistanceFft, (), {"reduction": reduction})
98
- if pretrained:
99
- loss.loss.load_state_dict(
100
- self.load_state_dict("gray_watson_fft_trial0.pth")
101
- )
102
- else:
103
- if deterministic:
104
- loss = ColorWrapper(WatsonDistanceFft, (), {"reduction": reduction})
105
- if pretrained:
106
- loss.load_state_dict(
107
- self.load_state_dict("rgb_watson_fft_trial0.pth")
108
- )
109
- else:
110
- loss = ShiftWrapper(
111
- ColorWrapper,
112
- (WatsonDistanceFft, (), {"reduction": reduction}),
113
- {},
114
- )
115
- if pretrained:
116
- loss.loss.load_state_dict(
117
- self.load_state_dict("rgb_watson_fft_trial0.pth")
118
- )
119
- elif model.lower() in ["watson-vgg", "watson-deep"]:
120
- if is_greyscale:
121
- loss = GreyscaleWrapper(WatsonDistanceVgg, (), {"reduction": reduction})
122
- if pretrained:
123
- loss.loss.load_state_dict(
124
- self.load_state_dict("gray_watson_vgg_trial0.pth")
125
- )
126
- else:
127
- loss = WatsonDistanceVgg(reduction=reduction)
128
- if pretrained:
129
- loss.load_state_dict(
130
- self.load_state_dict("rgb_watson_vgg_trial0.pth")
131
- )
132
- elif model.lower() in ["deeploss-vgg"]:
133
- if is_greyscale:
134
- loss = GreyscaleWrapper(
135
- PNetLin,
136
- (),
137
- {"pnet_type": "vgg", "reduction": reduction, "use_dropout": False},
138
- )
139
- if pretrained:
140
- loss.loss.load_state_dict(
141
- self.load_state_dict("gray_pnet_lin_vgg_trial0.pth")
142
- )
143
- else:
144
- loss = PNetLin(pnet_type="vgg", reduction=reduction, use_dropout=False)
145
- if pretrained:
146
- loss.load_state_dict(
147
- self.load_state_dict("rgb_pnet_lin_vgg_trial0.pth")
148
- )
149
- elif model.lower() in ["deeploss-squeeze"]:
150
- if is_greyscale:
151
- loss = GreyscaleWrapper(
152
- PNetLin,
153
- (),
154
- {
155
- "pnet_type": "squeeze",
156
- "reduction": reduction,
157
- "use_dropout": False,
158
- },
159
- )
160
- if pretrained:
161
- loss.loss.load_state_dict(
162
- self.load_state_dict("gray_pnet_lin_squeeze_trial0.pth")
163
- )
164
- else:
165
- loss = PNetLin(
166
- pnet_type="squeeze", reduction=reduction, use_dropout=False
167
- )
168
- if pretrained:
169
- loss.load_state_dict(
170
- self.load_state_dict("rgb_pnet_lin_squeeze_trial0.pth")
171
- )
172
- else:
173
- raise Exception('Metric "{}" not implemented'.format(model))
174
-
175
- # freeze all training of the loss functions
176
- if pretrained:
177
- for param in loss.parameters():
178
- param.requires_grad = False
179
-
180
- return loss
kit/metrics/watson/rfft2d.py DELETED
@@ -1,87 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.fft as fft
4
- import torch.nn.functional as F
5
-
6
-
7
- class Rfft2d(nn.Module):
8
- """
9
- Blockwhise 2D FFT
10
- for fixed blocksize of 8x8
11
- """
12
-
13
- def __init__(self, blocksize=8, interleaving=False):
14
- """
15
- Parameters:
16
- """
17
- super().__init__() # call super constructor
18
-
19
- self.blocksize = blocksize
20
- self.interleaving = interleaving
21
- if interleaving:
22
- self.stride = self.blocksize // 2
23
- else:
24
- self.stride = self.blocksize
25
-
26
- self.unfold = torch.nn.Unfold(
27
- kernel_size=self.blocksize, padding=0, stride=self.stride
28
- )
29
- return
30
-
31
- def forward(self, x):
32
- """
33
- performs 2D blockwhise DCT
34
-
35
- Parameters:
36
- x: tensor of dimension (N, 1, h, w)
37
-
38
- Return:
39
- tensor of dimension (N, k, b, b/2, 2)
40
- where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block real FFT coefficients.
41
- The last dimension is pytorches representation of complex values
42
- """
43
-
44
- (N, C, H, W) = x.shape
45
- assert C == 1, "FFT is only implemented for a single channel"
46
- assert H >= self.blocksize, "Input too small for blocksize"
47
- assert W >= self.blocksize, "Input too small for blocksize"
48
- assert (H % self.stride == 0) and (
49
- W % self.stride == 0
50
- ), "FFT is only for dimensions divisible by the blocksize"
51
-
52
- # unfold to blocks
53
- x = self.unfold(x)
54
- # now shape (N, 64, k)
55
- (N, _, k) = x.shape
56
- x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2)
57
- # now shape (N, #k, b, b)
58
- # perform DCT
59
- coeff = fft.rfft(x)
60
- coeff = torch.view_as_real(coeff)
61
-
62
- return coeff / self.blocksize**2
63
-
64
- def inverse(self, coeff, output_shape):
65
- """
66
- performs 2D blockwhise inverse rFFT
67
-
68
- Parameters:
69
- output_shape: Tuple, dimensions of the outpus sample
70
- """
71
- if self.interleaving:
72
- raise Exception(
73
- "Inverse block FFT is not implemented for interleaving blocks!"
74
- )
75
-
76
- # perform iRFFT
77
- x = fft.irfft(coeff, dim=2, signal_sizes=(self.blocksize, self.blocksize))
78
- (N, k, _, _) = x.shape
79
- x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k)
80
- x = F.fold(
81
- x,
82
- output_size=(output_shape[-2], output_shape[-1]),
83
- kernel_size=self.blocksize,
84
- padding=0,
85
- stride=self.blocksize,
86
- )
87
- return x * (self.blocksize**2)
kit/metrics/watson/shift_wrapper.py DELETED
@@ -1,51 +0,0 @@
-import torch.nn as nn
-import numpy as np
-
-
-class ShiftWrapper(nn.Module):
-    """
-    Extension for 2-dimensional inout loss functions.
-    Shifts the inputs by up to 4 pixels. Uses replication padding.
-    """
-
-    def __init__(self, lossclass, args, kwargs):
-        """
-        Parameters:
-        lossclass: class of the individual loss functions
-        trainable: bool, if True parameters of the loss are trained.
-        args: tuple, arguments for instantiation of loss fun
-        kwargs: dict, key word arguments for instantiation of loss fun
-        """
-        super().__init__()
-
-        # submodules
-        self.add_module("loss", lossclass(*args, **kwargs))
-
-        # shift amount
-        self.max_shift = 8
-
-        # padding
-        self.pad = nn.ReplicationPad2d(self.max_shift // 2)
-
-    def forward(self, input, target):
-        # convert color space
-        input = self.pad(input)
-        target = self.pad(target)
-
-        shift_x = np.random.randint(self.max_shift)
-        shift_y = np.random.randint(self.max_shift)
-
-        input = input[
-            :,
-            :,
-            shift_x : -(self.max_shift - shift_x),
-            shift_y : -(self.max_shift - shift_y),
-        ]
-        target = target[
-            :,
-            :,
-            shift_x : -(self.max_shift - shift_x),
-            shift_y : -(self.max_shift - shift_y),
-        ]
-
-        return self.loss(input, target)
kit/metrics/watson/ssim.py DELETED
@@ -1,95 +0,0 @@
1
- # SSIM implementation from https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
2
- import torch
3
- import torch.nn.functional as F
4
- from torch.autograd import Variable
5
- from math import exp
6
-
7
-
8
- def gaussian(window_size, sigma):
9
- gauss = torch.Tensor(
10
- [
11
- exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
12
- for x in range(window_size)
13
- ]
14
- )
15
- return gauss / gauss.sum()
16
-
17
-
18
- def create_window(window_size, channel):
19
- _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
20
- _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
21
- window = Variable(
22
- _2D_window.expand(channel, 1, window_size, window_size).contiguous()
23
- )
24
- return window
25
-
26
-
27
- def _ssim(img1, img2, window, window_size, channel, size_average=True):
28
- mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
29
- mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
30
-
31
- mu1_sq = mu1.pow(2)
32
- mu2_sq = mu2.pow(2)
33
- mu1_mu2 = mu1 * mu2
34
-
35
- sigma1_sq = (
36
- F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
37
- )
38
- sigma2_sq = (
39
- F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
40
- )
41
- sigma12 = (
42
- F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
43
- - mu1_mu2
44
- )
45
-
46
- C1 = 0.01**2
47
- C2 = 0.03**2
48
-
49
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
50
- (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
51
- )
52
-
53
- if size_average:
54
- return ssim_map.mean()
55
- else:
56
- return ssim_map.mean(1).mean(1).mean(1)
57
-
58
-
59
- class SSIM(torch.nn.Module):
60
- def __init__(self, window_size=11, size_average=True):
61
- super(SSIM, self).__init__()
62
- self.window_size = window_size
63
- self.size_average = size_average
64
- self.channel = 1
65
- self.window = create_window(window_size, self.channel)
66
-
67
- def forward(self, img1, img2):
68
- (_, channel, _, _) = img1.size()
69
-
70
- if channel == self.channel and self.window.data.type() == img1.data.type():
71
- window = self.window
72
- else:
73
- window = create_window(self.window_size, channel)
74
-
75
- if img1.is_cuda:
76
- window = window.cuda(img1.get_device())
77
- window = window.type_as(img1)
78
-
79
- self.window = window
80
- self.channel = channel
81
-
82
- return 1 - _ssim(
83
- img1, img2, window, self.window_size, channel, self.size_average
84
- )
85
-
86
-
87
- def ssim(img1, img2, window_size=11, size_average=True):
88
- (_, channel, _, _) = img1.size()
89
- window = create_window(window_size, channel)
90
-
91
- if img1.is_cuda:
92
- window = window.cuda(img1.get_device())
93
- window = window.type_as(img1)
94
-
95
- return _ssim(img1, img2, window, window_size, channel, size_average)
kit/metrics/watson/watson.py DELETED
@@ -1,123 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from .dct2d import Dct2d
5
-
6
- EPS = 1e-10
7
-
8
-
9
- def softmax(a, b, factor=1):
10
- concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
11
- softmax_factors = F.softmax(concat * factor, dim=-1)
12
- return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1]
13
-
14
-
15
- class WatsonDistance(nn.Module):
16
- """
17
- Loss function based on Watsons perceptual distance.
18
- Based on DCT quantization
19
- """
20
-
21
- def __init__(self, blocksize=8, trainable=False, reduction="sum"):
22
- """
23
- Parameters:
24
- blocksize: int, size of the Blocks for discrete cosine transform
25
- trainable: bool, if True parameters of the loss are trained and dropout is enabled.
26
- reduction: 'sum' or 'none', determines return format
27
- """
28
- super().__init__()
29
-
30
- # input mapping
31
- blocksize = torch.as_tensor(blocksize)
32
-
33
- # module to perform 2D blockwise DCT
34
- self.add_module("dct", Dct2d(blocksize=blocksize.item(), interleaving=False))
35
-
36
- # parameters, initialized with values from watson paper
37
- self.blocksize = nn.Parameter(blocksize, requires_grad=False)
38
- if self.blocksize == 8:
39
- # init with Jpeg QM
40
- self.t_tild = nn.Parameter(
41
- torch.log(
42
- torch.tensor( # log-scaled weights
43
- [
44
- [1.40, 1.01, 1.16, 1.66, 2.40, 3.43, 4.79, 6.56],
45
- [1.01, 1.45, 1.32, 1.52, 2.00, 2.71, 3.67, 4.93],
46
- [1.16, 1.32, 2.24, 2.59, 2.98, 3.64, 4.60, 5.88],
47
- [1.66, 1.52, 2.59, 3.77, 4.55, 5.30, 6.28, 7.60],
48
- [2.40, 2.00, 2.98, 4.55, 6.15, 7.46, 8.71, 10.17],
49
- [3.43, 2.71, 3.64, 5.30, 7.46, 9.62, 11.58, 13.51],
50
- [4.79, 3.67, 4.60, 6.28, 8.71, 11.58, 14.50, 17.29],
51
- [6.56, 4.93, 5.88, 7.60, 10.17, 13.51, 17.29, 21.15],
52
- ]
53
- )
54
- ),
55
- requires_grad=trainable,
56
- )
57
- else:
58
- # init with uniform QM
59
- self.t_tild = nn.Parameter(
60
- torch.zeros((self.blocksize, self.blocksize)), requires_grad=trainable
61
- )
62
-
63
- # other default parameters
64
- self.alpha = nn.Parameter(
65
- torch.tensor(0.649), requires_grad=trainable
66
- ) # luminance masking
67
- w = torch.tensor(0.7) # contrast masking
68
- self.w_tild = nn.Parameter(
69
- torch.log(w / (1 - w)), requires_grad=trainable
70
- ) # inverse of sigmoid
71
- self.beta = nn.Parameter(torch.tensor(4.0), requires_grad=trainable) # pooling
72
-
73
- # dropout for training
74
- self.dropout = nn.Dropout(0.5 if trainable else 0)
75
-
76
- # reduction
77
- self.reduction = reduction
78
- if reduction not in ["sum", "none"]:
79
- raise Exception(
80
- 'Reduction "{}" not supported. Valid values are: "sum", "none".'.format(
81
- reduction
82
- )
83
- )
84
-
85
- @property
86
- def t(self):
87
- # returns QM
88
- qm = torch.exp(self.t_tild)
89
- return qm
90
-
91
- @property
92
- def w(self):
93
- # return luminance masking parameter
94
- return torch.sigmoid(self.w_tild)
95
-
96
- def forward(self, input, target):
97
- # dct
98
- c0 = self.dct(target)
99
- c1 = self.dct(input)
100
-
101
- N, K, B, B = c0.shape
102
-
103
- # luminance masking
104
- avg_lum = torch.mean(c0[:, :, 0, 0])
105
- t_l = self.t.view(1, 1, B, B).expand(N, K, B, B)
106
- t_l = t_l * (((c0[:, :, 0, 0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(
107
- N, K, 1, 1
108
- )
109
-
110
- # contrast masking
111
- s = softmax(t_l, (c0.abs() + EPS) ** self.w * t_l ** (1 - self.w))
112
-
113
- # pooling
114
- watson_dist = (((c0 - c1) / s).abs() + EPS) ** self.beta
115
- watson_dist = self.dropout(watson_dist) + EPS
116
- watson_dist = torch.sum(watson_dist, dim=(1, 2, 3))
117
- watson_dist = watson_dist ** (1 / self.beta)
118
-
119
- # reduction
120
- if self.reduction == "sum":
121
- watson_dist = torch.sum(watson_dist)
122
-
123
- return watson_dist
kit/metrics/watson/watson_fft.py DELETED
@@ -1,139 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from .rfft2d import Rfft2d
-
-EPS = 1e-10
-
-
-def softmax(a, b, factor=1):
-    concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
-    softmax_factors = F.softmax(concat * factor, dim=-1)
-    return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1]
-
-
-class WatsonDistanceFft(nn.Module):
-    """
-    Loss function based on Watsons perceptual distance.
-    Based on FFT quantization
-    """
-
-    def __init__(self, blocksize=8, trainable=False, reduction="sum"):
-        """
-        Parameters:
-        blocksize: int, size of the Blocks for discrete cosine transform
-        trainable: bool, if True parameters of the loss are trained and dropout is enabled.
-        reduction: 'sum' or 'none', determines return format
-        """
-        super().__init__()
-        self.trainable = trainable
-
-        # input mapping
-        blocksize = torch.as_tensor(blocksize)
-
-        # module to perform 2D blockwise rFFT
-        self.add_module("fft", Rfft2d(blocksize=blocksize.item(), interleaving=False))
-
-        # parameters
-        self.weight_size = (blocksize, blocksize // 2 + 1)
-        self.blocksize = nn.Parameter(blocksize, requires_grad=False)
-        # init with uniform QM
-        self.t_tild = nn.Parameter(
-            torch.zeros(self.weight_size), requires_grad=trainable
-        )
-        self.alpha = nn.Parameter(
-            torch.tensor(0.1), requires_grad=trainable
-        )  # luminance masking
-        w = torch.tensor(0.2)  # contrast masking
-        self.w_tild = nn.Parameter(
-            torch.log(w / (1 - w)), requires_grad=trainable
-        )  # inverse of sigmoid
-        self.beta = nn.Parameter(torch.tensor(1.0), requires_grad=trainable)  # pooling
-
-        # phase weights
-        self.w_phase_tild = nn.Parameter(
-            torch.zeros(self.weight_size) - 2.0, requires_grad=trainable
-        )
-
-        # dropout for training
-        self.dropout = nn.Dropout(0.5 if trainable else 0)
-
-        # reduction
-        self.reduction = reduction
-        if reduction not in ["sum", "none"]:
-            raise Exception(
-                'Reduction "{}" not supported. Valid values are: "sum", "none".'.format(
-                    reduction
-                )
-            )
-
-    @property
-    def t(self):
-        # returns QM
-        qm = torch.exp(self.t_tild)
-        return qm
-
-    @property
-    def w(self):
-        # return luminance masking parameter
-        return torch.sigmoid(self.w_tild)
-
-    @property
-    def w_phase(self):
-        # return weights for phase
-        w_phase = torch.exp(self.w_phase_tild)
-        # set weights of non-phases to 0
-        if not self.trainable:
-            w_phase[0, 0] = 0.0
-            w_phase[0, self.weight_size[1] - 1] = 0.0
-            w_phase[self.weight_size[1] - 1, self.weight_size[1] - 1] = 0.0
-            w_phase[self.weight_size[1] - 1, 0] = 0.0
-        return w_phase
-
-    def forward(self, input, target):
-        # fft
-        c0 = self.fft(target)
-        c1 = self.fft(input)
-
-        N, K, H, W, _ = c0.shape
-
-        # get amplitudes
-        c0_amp = torch.norm(c0 + EPS, p="fro", dim=4)
-        c1_amp = torch.norm(c1 + EPS, p="fro", dim=4)
-
-        # luminance masking
-        avg_lum = torch.mean(c0_amp[:, :, 0, 0])
-        t_l = self.t.view(1, 1, H, W).expand(N, K, H, W)
-        t_l = t_l * (((c0_amp[:, :, 0, 0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(
-            N, K, 1, 1
-        )
-
-        # contrast masking
-        s = softmax(t_l, (c0_amp.abs() + EPS) ** self.w * t_l ** (1 - self.w))
-
-        # pooling
-        watson_dist = (((c0_amp - c1_amp) / s).abs() + EPS) ** self.beta
-        watson_dist = self.dropout(watson_dist) + EPS
-        watson_dist = torch.sum(watson_dist, dim=(1, 2, 3))
-        watson_dist = watson_dist ** (1 / self.beta)
-
-        # get phases
-        c0_phase = torch.atan2(c0[:, :, :, :, 1], c0[:, :, :, :, 0] + EPS)
-        c1_phase = torch.atan2(c1[:, :, :, :, 1], c1[:, :, :, :, 0] + EPS)
-
-        # angular distance
-        phase_dist = (
-            torch.acos(torch.cos(c0_phase - c1_phase) * (1 - EPS * 10**3))
-            * self.w_phase
-        )  # we multiply with a factor ->1 to prevent taking the gradient of acos(-1) or acos(1). The gradient in this case would be -/+ inf
-        phase_dist = self.dropout(phase_dist)
-        phase_dist = torch.sum(phase_dist, dim=(1, 2, 3))
-
-        # perceptual distance
-        distance = watson_dist + phase_dist
-
-        # reduce
-        if self.reduction == "sum":
-            distance = torch.sum(distance)
-
-        return distance
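For reference, the removed WatsonDistanceFft was a drop-in nn.Module loss. The sketch below is illustrative only: it assumes the pre-commit import path (kit/metrics/watson/watson_fft.py with its Rfft2d dependency) and single-channel inputs in [0, 1]; neither convention is documented in the file itself.

# Minimal usage sketch (assumptions: pre-commit module path is importable,
# greyscale images in [0, 1]).
import torch
from kit.metrics.watson.watson_fft import WatsonDistanceFft

loss_fn = WatsonDistanceFft(blocksize=8, trainable=False, reduction="sum")

original = torch.rand(4, 1, 256, 256)                                 # batch of greyscale images
attacked = (original + 0.02 * torch.randn_like(original)).clamp(0, 1)  # perturbed copies

with torch.no_grad():
    dist = loss_fn(attacked, original)   # scalar, since reduction="sum"
print(float(dist))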
kit/metrics/watson/watson_vgg.py DELETED
@@ -1,202 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
-
-EPS = 1e-10
-
-
-class VggFeatureExtractor(nn.Module):
-    def __init__(self):
-        super(VggFeatureExtractor, self).__init__()
-
-        # download vgg
-        vgg16 = torchvision.models.vgg16(pretrained=True).features
-
-        # set non trainable
-        for param in vgg16.parameters():
-            param.requires_grad = False
-
-        # slice model
-        self.slice1 = torch.nn.Sequential()
-        self.slice2 = torch.nn.Sequential()
-        self.slice3 = torch.nn.Sequential()
-        self.slice4 = torch.nn.Sequential()
-        self.slice5 = torch.nn.Sequential()
-
-        for x in range(4):  # conv relu conv relu
-            self.slice1.add_module(str(x), vgg16[x])
-        for x in range(4, 9):  # max conv relu conv relu
-            self.slice2.add_module(str(x), vgg16[x])
-        for x in range(9, 16):  # max cov relu conv relu conv relu
-            self.slice3.add_module(str(x), vgg16[x])
-        for x in range(16, 23):  # conv relu max conv relu conv relu
-            self.slice4.add_module(str(x), vgg16[x])
-        for x in range(23, 30):  # conv relu conv relu max conv relu
-            self.slice5.add_module(str(x), vgg16[x])
-
-    def forward(self, X):
-        h = self.slice1(X)
-        h_relu1_2 = h
-        h = self.slice2(h)
-        h_relu2_2 = h
-        h = self.slice3(h)
-        h_relu3_3 = h
-        h = self.slice4(h)
-        h_relu4_3 = h
-        h = self.slice5(h)
-        h_relu5_3 = h
-
-        return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
-
-
-def normalize_tensor(t):
-    # norms a tensor over the channel dimension to an euclidean length of 1.
-    N, C, H, W = t.shape
-    norm_factor = torch.sqrt(torch.sum(t**2, dim=1)).view(N, 1, H, W)
-    return t / (norm_factor.expand_as(t) + EPS)
-
-
-def softmax(a, b, factor=1):
-    concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1)
-    softmax_factors = F.softmax(concat * factor, dim=-1)
-    return a * softmax_factors[:, :, :, :, 0] + b * softmax_factors[:, :, :, :, 1]
-
-
-class WatsonDistanceVgg(nn.Module):
-    """
-    Loss function based on Watsons perceptual distance.
-    Based on deep feature extraction
-    """
-
-    def __init__(self, trainable=False, reduction="sum"):
-        """
-        Parameters:
-        trainable: bool, if True parameters of the loss are trained and dropout is enabled.
-        reduction: 'sum' or 'none', determines return format
-        """
-        super().__init__()
-
-        # module to perform feature extraction
-        self.add_module("vgg", VggFeatureExtractor())
-
-        # imagenet-normalization
-        self.shift = nn.Parameter(
-            torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1), requires_grad=False
-        )
-        self.scale = nn.Parameter(
-            torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1), requires_grad=False
-        )
-
-        # channel dimensions
-        self.L = 5
-        self.channels = [64, 128, 256, 512, 512]
-
-        # sensitivity parameters
-        self.t0_tild = nn.Parameter(
-            torch.zeros((self.channels[0])), requires_grad=trainable
-        )
-        self.t1_tild = nn.Parameter(
-            torch.zeros((self.channels[1])), requires_grad=trainable
-        )
-        self.t2_tild = nn.Parameter(
-            torch.zeros((self.channels[2])), requires_grad=trainable
-        )
-        self.t3_tild = nn.Parameter(
-            torch.zeros((self.channels[3])), requires_grad=trainable
-        )
-        self.t4_tild = nn.Parameter(
-            torch.zeros((self.channels[4])), requires_grad=trainable
-        )
-
-        # other default parameters
-        w = torch.tensor(0.2)  # contrast masking
-        self.w0_tild = nn.Parameter(
-            torch.log(w / (1 - w)), requires_grad=trainable
-        )  # inverse of sigmoid
-        self.w1_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.w2_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.w3_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.w4_tild = nn.Parameter(torch.log(w / (1 - w)), requires_grad=trainable)
-        self.beta = nn.Parameter(torch.tensor(1.0), requires_grad=trainable)  # pooling
-
-        # dropout for training
-        self.dropout = nn.Dropout(0.5 if trainable else 0)
-
-        # reduction
-        self.reduction = reduction
-        if reduction not in ["sum", "none"]:
-            raise Exception(
-                'Reduction "{}" not supported. Valid values are: "sum", "none".'.format(
-                    reduction
-                )
-            )
-
-    @property
-    def t(self):
-        return [
-            torch.exp(t)
-            for t in [
-                self.t0_tild,
-                self.t1_tild,
-                self.t2_tild,
-                self.t3_tild,
-                self.t4_tild,
-            ]
-        ]
-
-    @property
-    def w(self):
-        # return luminance masking parameter
-        return [
-            torch.sigmoid(w)
-            for w in [
-                self.w0_tild,
-                self.w1_tild,
-                self.w2_tild,
-                self.w3_tild,
-                self.w4_tild,
-            ]
-        ]
-
-    def forward(self, input, target):
-        # normalization
-        input = (input - self.shift.expand_as(input)) / self.scale.expand_as(input)
-        target = (target - self.shift.expand_as(target)) / self.scale.expand_as(target)
-
-        # feature extraction
-        c0 = self.vgg(target)
-        c1 = self.vgg(input)
-
-        # norm over channels
-        for l in range(self.L):
-            c0[l] = normalize_tensor(c0[l])
-            c1[l] = normalize_tensor(c1[l])
-
-        # contrast masking
-        t = self.t
-        w = self.w
-        s = []
-        for l in range(self.L):
-            N, C_l, H_l, W_l = c0[l].shape
-            t_l = t[l].view(1, C_l, 1, 1).expand(N, C_l, H_l, W_l)
-            s.append(softmax(t_l, (c0[l].abs() + EPS) ** w[l] * t_l ** (1 - w[l])))
-
-        # pooling
-        watson_dist = 0
-        for l in range(self.L):
-            _, _, H_l, W_l = c0[l].shape
-            layer_dist = (((c0[l] - c1[l]) / s[l]).abs() + EPS) ** self.beta
-            layer_dist = self.dropout(layer_dist) + EPS
-            layer_dist = torch.sum(
-                layer_dist, dim=(1, 2, 3)
-            )  # sum over dimensions of layer
-            layer_dist = (1 / (H_l * W_l)) * layer_dist  # normalize by layer size
-            watson_dist += layer_dist  # sum over layers
-        watson_dist = watson_dist ** (1 / self.beta)
-
-        # reduction
-        if self.reduction == "sum":
-            watson_dist = torch.sum(watson_dist)
-
-        return watson_dist
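Likewise, the removed WatsonDistanceVgg wraps a frozen VGG-16 feature extractor. The sketch below is a rough illustration under two assumptions that the file does not state: the pre-commit import path, and 3-channel inputs scaled to [-1, 1] (the shift/scale constants above follow the LPIPS normalization convention).

# Minimal usage sketch (assumptions: pre-commit module path, RGB inputs in [-1, 1]).
import torch
from kit.metrics.watson.watson_vgg import WatsonDistanceVgg

loss_fn = WatsonDistanceVgg(trainable=False, reduction="sum")  # downloads VGG-16 weights on first use

img_a = torch.rand(2, 3, 256, 256) * 2 - 1                     # reference images
img_b = (img_a + 0.05 * torch.randn_like(img_a)).clamp(-1, 1)  # perturbed copies

with torch.no_grad():
    dist = loss_fn(img_b, img_a)   # scalar, since reduction="sum"
print(float(dist))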
kit/models/stable_signature.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b58841ab09f23e89acf5aedade09c7f65908ae33437c5242ad987d99b5cd2c1
+size 1228161
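The three added lines are a Git LFS pointer, not the model itself; the actual ~1.2 MB stable_signature.onnx is fetched by `git lfs pull`. A quick sanity check that the pointer resolved to real weights is sketched below; input names and shapes are read from the model rather than assumed, and onnxruntime is an assumed dependency.

# Sketch only: confirm the LFS-tracked ONNX file loads (run `git lfs pull` first).
import onnxruntime as ort

sess = ort.InferenceSession("kit/models/stable_signature.onnx")
for inp in sess.get_inputs():
    print(inp.name, inp.shape)   # inspect expected inputs before running inference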
requirements.txt CHANGED
@@ -5,6 +5,7 @@ torchvision
 transformers
 open_clip_torch
 numpy
+scipy
 Pillow
 redis
 plotly