import torch
import torch.nn as nn
class SignWithSigmoidGrad(torch.autograd.Function):
    """Hard sign in the forward pass, sigmoid derivative as the surrogate gradient."""

    @staticmethod
    def forward(ctx, x):
        result = (x > 0).float()
        sigmoid_result = torch.sigmoid(x)
        ctx.save_for_backward(sigmoid_result)
        return result

    @staticmethod
    def backward(ctx, grad_result):
        (sigmoid_result,) = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)
        else:
            grad_input = None
        return grad_input
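
# Usage sketch (added for illustration; the demo helper below is not part of the original
# file): apply the straight-through sign to a tensor and check that the gradient matches
# the derivative of sigmoid, i.e. sigmoid(x) * (1 - sigmoid(x)).
def _demo_sign_with_sigmoid_grad():
    x = torch.randn(4, requires_grad=True)
    y = SignWithSigmoidGrad.apply(x)  # hard 0/1 values in the forward pass
    y.sum().backward()                # gradients flow through the sigmoid surrogate
    with torch.no_grad():
        expected = torch.sigmoid(x) * (1 - torch.sigmoid(x))
    assert torch.allclose(x.grad, expected)
    return y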
class Painter(nn.Module):
    def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
        super().__init__()
        # CNN encoder for the target image: two stride-2 stages downsample by 4x.
        self.enc_img = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 32, 3, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(32, 64, 3, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 128, 3, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(True))
        # Identical encoder for the current canvas.
        self.enc_canvas = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 32, 3, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(32, 64, 3, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 128, 3, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(True))
        # 1x1 conv fuses the concatenated image/canvas features down to hidden_dim channels.
        self.conv = nn.Conv2d(128 * 2, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
        # MLP head mapping each decoder output to one stroke's parameters.
        self.linear_param = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, param_per_stroke))
        # Scalar score per stroke deciding whether the stroke should be drawn.
        self.linear_decider = nn.Linear(hidden_dim, 1)
        # Learned stroke queries and a learned 2D positional embedding (8x8 grid).
        self.query_pos = nn.Parameter(torch.rand(total_strokes, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
    def forward(self, img, canvas):
        b, _, H, W = img.shape
        img_feat = self.enc_img(img)
        canvas_feat = self.enc_canvas(canvas)
        h, w = img_feat.shape[-2:]
        # Concatenate image and canvas features along the channel axis and fuse.
        feat = torch.cat([img_feat, canvas_feat], dim=1)
        feat_conv = self.conv(feat)
        # Build a (h*w, 1, hidden_dim) positional embedding from the row/column tables.
        pos_embed = torch.cat([
            self.col_embed[:w].unsqueeze(0).contiguous().repeat(h, 1, 1),
            self.row_embed[:h].unsqueeze(1).contiguous().repeat(1, w, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        # Encoder input: flattened feature map plus positions, shape (h*w, batch, hidden_dim).
        # Decoder input: one learned query per stroke, shape (total_strokes, batch, hidden_dim).
        hidden_state = self.transformer(pos_embed + feat_conv.flatten(2).permute(2, 0, 1).contiguous(),
                                        self.query_pos.unsqueeze(1).contiguous().repeat(1, b, 1))
        hidden_state = hidden_state.permute(1, 0, 2).contiguous()
        param = self.linear_param(hidden_state)
        decision = self.linear_decider(hidden_state)
        return param, decision
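
# Smoke-test sketch (illustrative only; the hyperparameters and the 32x32 input size are
# assumptions chosen so the 4x-downsampled 8x8 feature map matches the 8-entry row/column
# embedding tables above, not values taken from the original Space).
def _demo_painter():
    painter = Painter(param_per_stroke=8, total_strokes=8, hidden_dim=256)
    img = torch.randn(2, 3, 32, 32)     # target image patch batch
    canvas = torch.randn(2, 3, 32, 32)  # current canvas patch batch
    with torch.no_grad():
        param, decision = painter(img, canvas)
    # param:    (batch, total_strokes, param_per_stroke) stroke parameters
    # decision: (batch, total_strokes, 1) per-stroke confidence scores
    return param.shape, decision.shape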