Spaces:
Runtime error
Runtime error
import time | |
import itertools | |
from dataset import ImageFolder | |
from torchvision import transforms | |
from torch.utils.data import DataLoader | |
from .networks import * | |
from utils import * | |
from glob import glob | |
from .face_features import FaceFeatures | |
class UgatitSadalinHourglass(object): | |
def __init__(self, args): | |
self.light = args.light | |
if self.light: | |
self.model_name = 'UGATIT_light' | |
else: | |
self.model_name = 'UGATIT' | |
self.result_dir = args.result_dir | |
self.dataset = args.dataset | |
self.iteration = args.iteration | |
self.decay_flag = args.decay_flag | |
self.batch_size = args.batch_size | |
self.print_freq = args.print_freq | |
self.save_freq = args.save_freq | |
self.lr = args.lr | |
self.ch = args.ch | |
""" Weight """ | |
self.adv_weight = args.adv_weight | |
self.cycle_weight = args.cycle_weight | |
self.identity_weight = args.identity_weight | |
self.cam_weight = args.cam_weight | |
self.faceid_weight = args.faceid_weight | |
""" Discriminator """ | |
self.n_dis = args.n_dis | |
self.img_size = args.img_size | |
self.img_ch = args.img_ch | |
self.device = f'cuda:{args.gpu_ids[0]}' | |
self.gpu_ids = args.gpu_ids | |
self.benchmark_flag = args.benchmark_flag | |
self.resume = args.resume | |
self.rho_clipper = args.rho_clipper | |
self.w_clipper = args.w_clipper | |
self.pretrained_weights = args.pretrained_weights | |
if torch.backends.cudnn.enabled and self.benchmark_flag: | |
print('set benchmark !') | |
torch.backends.cudnn.benchmark = True | |
print("##### Information #####") | |
print("# light : ", self.light) | |
print("# dataset : ", self.dataset) | |
print("# batch_size : ", self.batch_size) | |
print("# iteration per epoch : ", self.iteration) | |
print("##### Discriminator #####") | |
print("# discriminator layer : ", self.n_dis) | |
print() | |
print("##### Weight #####") | |
print("# adv_weight : ", self.adv_weight) | |
print("# cycle_weight : ", self.cycle_weight) | |
print("# faceid_weight : ", self.faceid_weight) | |
print("# identity_weight : ", self.identity_weight) | |
print("# cam_weight : ", self.cam_weight) | |
print("# rho_clipper: ", self.rho_clipper) | |
print("# w_clipper: ", self.w_clipper) | |
################################################################################## | |
# Model | |
################################################################################## | |
def build_model(self): | |
""" DataLoader """ | |
train_transform = transforms.Compose([ | |
transforms.RandomHorizontalFlip(), | |
transforms.Resize((self.img_size + 30, self.img_size+30)), | |
transforms.RandomCrop(self.img_size), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) | |
]) | |
test_transform = transforms.Compose([ | |
transforms.Resize((self.img_size, self.img_size)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) | |
]) | |
self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform) | |
self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform) | |
self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform) | |
self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform) | |
self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True) | |
self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True) | |
self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False) | |
self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False) | |
""" Define Generator, Discriminator """ | |
self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device) | |
self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device) | |
self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device) | |
self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device) | |
self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device) | |
self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device) | |
self.facenet = FaceFeatures('models/model_mobilefacenet.pth', self.device) | |
""" Define Loss """ | |
self.L1_loss = nn.L1Loss().to(self.device) | |
self.MSE_loss = nn.MSELoss().to(self.device) | |
self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device) | |
""" Trainer """ | |
self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001) | |
self.D_optim = torch.optim.Adam( | |
itertools.chain(self.disGA.parameters(), self.disGB.parameters(), self.disLA.parameters(), self.disLB.parameters()), | |
lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001 | |
) | |
""" Define Rho clipper to constraint the value of rho in AdaLIN and LIN""" | |
self.Rho_clipper = RhoClipper(0, self.rho_clipper) | |
self.W_Clipper = WClipper(0, self.w_clipper) | |
def train(self): | |
self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train() | |
start_iter = 1 | |
if self.resume: | |
model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt')) | |
if not len(model_list) == 0: | |
model_list.sort() | |
start_iter = int(model_list[-1].split('_')[-1].split('.')[0]) | |
self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter) | |
print(" [*] Load SUCCESS") | |
if self.decay_flag and start_iter > (self.iteration // 2): | |
self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2) | |
self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2) | |
if self.pretrained_weights: | |
params = torch.load(self.pretrained_weights, map_location=self.device) | |
self.genA2B.load_state_dict(params['genA2B']) | |
self.genB2A.load_state_dict(params['genB2A']) | |
self.disGA.load_state_dict(params['disGA']) | |
self.disGB.load_state_dict(params['disGB']) | |
self.disLA.load_state_dict(params['disLA']) | |
self.disLB.load_state_dict(params['disLB']) | |
print(" [*] Load {} Success".format(self.pretrained_weights)) | |
if len(self.gpu_ids) > 1: | |
self.genA2B = nn.DataParallel(self.genA2B, device_ids=self.gpu_ids) | |
self.genB2A = nn.DataParallel(self.genB2A, device_ids=self.gpu_ids) | |
self.disGA = nn.DataParallel(self.disGA, device_ids=self.gpu_ids) | |
self.disGB = nn.DataParallel(self.disGB, device_ids=self.gpu_ids) | |
self.disLA = nn.DataParallel(self.disLA, device_ids=self.gpu_ids) | |
self.disLB = nn.DataParallel(self.disLB, device_ids=self.gpu_ids) | |
# training loop | |
print('training start !') | |
start_time = time.time() | |
for step in range(start_iter, self.iteration + 1): | |
if self.decay_flag and step > (self.iteration // 2): | |
self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) | |
self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) | |
try: | |
real_A, _ = trainA_iter.next() | |
except: | |
trainA_iter = iter(self.trainA_loader) | |
real_A, _ = trainA_iter.next() | |
try: | |
real_B, _ = trainB_iter.next() | |
except: | |
trainB_iter = iter(self.trainB_loader) | |
real_B, _ = trainB_iter.next() | |
real_A, real_B = real_A.to(self.device), real_B.to(self.device) | |
# Update D | |
self.D_optim.zero_grad() | |
fake_A2B, _, _ = self.genA2B(real_A) | |
fake_B2A, _, _ = self.genB2A(real_B) | |
real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A) | |
real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A) | |
real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B) | |
real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B) | |
fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A) | |
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A) | |
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B) | |
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B) | |
D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + \ | |
self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device)) | |
D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + \ | |
self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device)) | |
D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + \ | |
self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device)) | |
D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) +\ | |
self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device)) | |
D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + \ | |
self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device)) | |
D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + \ | |
self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device)) | |
D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + \ | |
self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device)) | |
D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) +\ | |
self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device)) | |
D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA) | |
D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB) | |
Discriminator_loss = D_loss_A + D_loss_B | |
Discriminator_loss.backward() | |
self.D_optim.step() | |
# Update G | |
self.G_optim.zero_grad() | |
fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A) | |
fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B) | |
fake_A2B2A, _, _ = self.genB2A(fake_A2B) | |
fake_B2A2B, _, _ = self.genA2B(fake_B2A) | |
fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A) | |
fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B) | |
fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A) | |
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A) | |
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B) | |
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B) | |
G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device)) | |
G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device)) | |
G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device)) | |
G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device)) | |
G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device)) | |
G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device)) | |
G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device)) | |
G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device)) | |
G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A) | |
G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B) | |
G_identity_loss_A = self.L1_loss(fake_A2A, real_A) | |
G_identity_loss_B = self.L1_loss(fake_B2B, real_B) | |
G_id_loss_A = self.facenet.cosine_distance(real_A, fake_A2B) | |
G_id_loss_B = self.facenet.cosine_distance(real_B, fake_B2A) | |
if len(self.gpu_ids) > 1: | |
G_id_loss_A = torch.mean(G_id_loss_A) | |
G_id_loss_B = torch.mean(G_id_loss_B) | |
G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + \ | |
self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device)) | |
G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + \ | |
self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device)) | |
G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + \ | |
self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + \ | |
self.cam_weight * G_cam_loss_A + self.faceid_weight * G_id_loss_A | |
G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + \ | |
self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + \ | |
self.cam_weight * G_cam_loss_B + self.faceid_weight * G_id_loss_B | |
Generator_loss = G_loss_A + G_loss_B | |
Generator_loss.backward() | |
self.G_optim.step() | |
# clip parameter of Soft-AdaLIN and LIN, applied after optimizer step | |
self.genA2B.apply(self.Rho_clipper) | |
self.genB2A.apply(self.Rho_clipper) | |
self.genA2B.apply(self.W_Clipper) | |
self.genB2A.apply(self.W_Clipper) | |
if step % 10 == 0: | |
print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iteration, time.time() - start_time, Discriminator_loss, Generator_loss)) | |
if step % self.print_freq == 0: | |
train_sample_num = 5 | |
test_sample_num = 5 | |
A2B = np.zeros((self.img_size * 7, 0, 3)) | |
B2A = np.zeros((self.img_size * 7, 0, 3)) | |
self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval() | |
with torch.no_grad(): | |
for _ in range(train_sample_num): | |
try: | |
real_A, _ = trainA_iter.next() | |
except: | |
trainA_iter = iter(self.trainA_loader) | |
real_A, _ = trainA_iter.next() | |
try: | |
real_B, _ = trainB_iter.next() | |
except: | |
trainB_iter = iter(self.trainB_loader) | |
real_B, _ = trainB_iter.next() | |
real_A, real_B = real_A.to(self.device), real_B.to(self.device) | |
fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) | |
fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) | |
fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) | |
fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) | |
fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) | |
fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) | |
A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), | |
cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), | |
cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), | |
cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1) | |
B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), | |
cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), | |
cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), | |
cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1) | |
for _ in range(test_sample_num): | |
try: | |
real_A, _ = testA_iter.next() | |
except: | |
testA_iter = iter(self.testA_loader) | |
real_A, _ = testA_iter.next() | |
try: | |
real_B, _ = testB_iter.next() | |
except: | |
testB_iter = iter(self.testB_loader) | |
real_B, _ = testB_iter.next() | |
real_A, real_B = real_A.to(self.device), real_B.to(self.device) | |
fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) | |
fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) | |
fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) | |
fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) | |
fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) | |
fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) | |
A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), | |
cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), | |
cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), | |
cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1) | |
B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), | |
cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), | |
cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), | |
cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1) | |
cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0) | |
cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0) | |
self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train() | |
if step % self.save_freq == 0: | |
self.save(os.path.join(self.result_dir, self.dataset, 'model'), step) | |
if step % 1000 == 0: | |
params = {} | |
if len(self.gpu_ids) > 1: | |
params['genA2B'] = self.genA2B.module.state_dict() | |
params['genB2A'] = self.genB2A.module.state_dict() | |
params['disGA'] = self.disGA.module.state_dict() | |
params['disGB'] = self.disGB.module.state_dict() | |
params['disLA'] = self.disLA.module.state_dict() | |
params['disLB'] = self.disLB.module.state_dict() | |
else: | |
params['genA2B'] = self.genA2B.state_dict() | |
params['genB2A'] = self.genB2A.state_dict() | |
params['disGA'] = self.disGA.state_dict() | |
params['disGB'] = self.disGB.state_dict() | |
params['disLA'] = self.disLA.state_dict() | |
params['disLB'] = self.disLB.state_dict() | |
torch.save(params, os.path.join(self.result_dir, self.dataset + '_params_latest.pt')) | |
def save(self, dir, step): | |
params = {} | |
if len(self.gpu_ids) > 1: | |
params['genA2B'] = self.genA2B.module.state_dict() | |
params['genB2A'] = self.genB2A.module.state_dict() | |
params['disGA'] = self.disGA.module.state_dict() | |
params['disGB'] = self.disGB.module.state_dict() | |
params['disLA'] = self.disLA.module.state_dict() | |
params['disLB'] = self.disLB.module.state_dict() | |
else: | |
params['genA2B'] = self.genA2B.state_dict() | |
params['genB2A'] = self.genB2A.state_dict() | |
params['disGA'] = self.disGA.state_dict() | |
params['disGB'] = self.disGB.state_dict() | |
params['disLA'] = self.disLA.state_dict() | |
params['disLB'] = self.disLB.state_dict() | |
torch.save(params, os.path.join(dir, self.dataset + '_params_%07d.pt' % step)) | |
def load(self, dir, step): | |
params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step)) | |
self.genA2B.load_state_dict(params['genA2B']) | |
self.genB2A.load_state_dict(params['genB2A']) | |
self.disGA.load_state_dict(params['disGA']) | |
self.disGB.load_state_dict(params['disGB']) | |
self.disLA.load_state_dict(params['disLA']) | |
self.disLB.load_state_dict(params['disLB']) | |
def test(self): | |
model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt')) | |
if not len(model_list) == 0: | |
model_list.sort() | |
iter = int(model_list[-1].split('_')[-1].split('.')[0]) | |
self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter) | |
print(" [*] Load SUCCESS") | |
else: | |
print(" [*] Load FAILURE") | |
return | |
self.genA2B.eval(), self.genB2A.eval() | |
with torch.no_grad(): | |
for n, (real_A, _) in enumerate(self.testA_loader): | |
real_A = real_A.to(self.device) | |
fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A) | |
fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B) | |
fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A) | |
A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))), | |
cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), | |
cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), | |
cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0) | |
cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0) | |
for n, (real_B, _) in enumerate(self.testB_loader): | |
real_B = real_B.to(self.device) | |
fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B) | |
fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A) | |
fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B) | |
B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))), | |
cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), | |
cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), | |
cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size), | |
RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0) | |
cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0) | |