ZhengPeng7
commited on
Commit
•
e88de74
1
Parent(s):
fa9100b
Dtype adaptability between FP32 and FP16 in inference.
Browse files- birefnet.py +1 -1
birefnet.py
CHANGED
@@ -992,7 +992,7 @@ class BasicLayer(nn.Module):
|
|
992 |
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
993 |
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
994 |
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
995 |
-
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
996 |
|
997 |
for blk in self.blocks:
|
998 |
blk.H, blk.W = H, W
|
|
|
992 |
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
993 |
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
994 |
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
995 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
|
996 |
|
997 |
for blk in self.blocks:
|
998 |
blk.H, blk.W = H, W
|