# (removed: web-page scrape artifacts — hosting-page header, "Runtime error" banner,
#  file-size line, commit hash, and line-number gutter; not part of the source file)
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
import os
import sys
import argparse
from typing import Optional
from datetime import datetime
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dalle.models import ImageGPT
# Command-line interface: all paths are mandatory; hardware and seed have defaults.
parser = argparse.ArgumentParser()
_ARG_SPECS = [
    (('-d', '--config-downstream'), dict(type=str, default=None, required=True)),
    (('-u', '--path-upstream'), dict(type=str, default=None, required=True)),
    (('-r', '--result-path'), dict(type=str, default=None, required=True)),
    (('--imagenet-path',), dict(type=str, default=None, required=True)),
    (('--n-gpus',), dict(type=int, default=1)),
    (('--seed',), dict(type=int, default=0)),
]
for _flags, _options in _ARG_SPECS:
    parser.add_argument(*_flags, **_options)
args = parser.parse_args()
class ImageLogger(Callback):
    """Logs grids of input images and stage-1 reconstructions to TensorBoard.

    Only batch 0 of each of the first five epochs is logged (train and
    validation), and only from rank zero under distributed training.
    """

    def __init__(self):
        super().__init__()

    @rank_zero_only
    def log_img(self, pl_module, batch, current_epoch, split="train"):
        """Write 8-per-row grids of originals and reconstructions for `batch`."""
        with torch.no_grad():
            images, _labels = batch
            recons = pl_module.stage1(images).cpu()
            images = images.cpu()
            # Inputs are normalized to [-1, 1]; map back to [0, 1] for display.
            grid_org = (torchvision.utils.make_grid(images, nrow=8) + 1.0) / 2.0
            grid_rec = torch.clamp(
                (torchvision.utils.make_grid(recons, nrow=8) + 1.0) / 2.0,
                min=0, max=1,
            )
            board = pl_module.logger.experiment
            board.add_image(f"images_org/{split}", grid_org, global_step=current_epoch)
            board.add_image(f"images_rec/{split}", grid_rec, global_step=current_epoch)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0 and trainer.current_epoch < 5:
            self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0 and trainer.current_epoch < 5:
            self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="test")
class ImageNetDataModule(pl.LightningDataModule):
    """ImageNet data module with [-1, 1]-normalized crops.

    Training uses Resize + RandomCrop; validation uses Resize + CenterCrop
    so evaluation is deterministic.

    Args:
        data_dir: root directory of the torchvision-format ImageNet dataset.
        image_resolution: side length of the square crops fed to the model.
        train_batch_size: per-process batch size for training.
        valid_batch_size: per-process batch size for validation.
        num_workers: DataLoader worker processes per loader.
    """

    def __init__(self,
                 data_dir: Optional[str] = None,
                 image_resolution: int = 256,
                 train_batch_size: int = 2,
                 valid_batch_size: int = 32,
                 num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.image_resolution = image_resolution
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
        self.num_workers = num_workers
        # Normalize RGB to [-1, 1]; ImageLogger maps back to [0, 1] for display.
        self.train_transform = transforms.Compose(
            [transforms.Resize(image_resolution),
             transforms.RandomCrop(image_resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        )
        self.valid_transform = transforms.Compose(
            [transforms.Resize(image_resolution),
             transforms.CenterCrop(image_resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        )

    def setup(self, stage=None):
        self.trainset = torchvision.datasets.ImageNet(root=self.data_dir, split='train', transform=self.train_transform)
        self.validset = torchvision.datasets.ImageNet(root=self.data_dir, split='val', transform=self.valid_transform)

    def train_dataloader(self):
        # BUGFIX: shuffle=True. ImageNet's train split is ordered by class, so
        # without shuffling every batch contains a single class; Lightning's DDP
        # sampler replacement preserves the (non-)shuffling choice made here.
        return DataLoader(self.trainset,
                          batch_size=self.train_batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def valid_dataloader(self):
        # Deterministic order for evaluation; no shuffling.
        return DataLoader(self.validset,
                          batch_size=self.valid_batch_size,
                          num_workers=self.num_workers,
                          pin_memory=True)
def setup_callbacks(config):
    """Build the checkpoint callback, TensorBoard logger, and image logger.

    Artifacts land under <result-path>/<config-name>/<timestamp>/{ckpt,log},
    where the paths come from the module-level `args`.
    """
    timestamp = datetime.now().strftime('%d%m%Y_%H%M%S')
    # Name taken up to the first dot of the downstream config's basename.
    config_name = os.path.basename(args.config_downstream).split('.')[0]
    result_path = os.path.join(args.result_path, config_name, timestamp)

    if config.stage2.use_cls_cond:
        ckpt_filename = "imagenet-clscond-gen-{epoch:02d}"
    else:
        ckpt_filename = "imagenet-uncond-gen-{epoch:02d}"

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(result_path, 'ckpt'),
        filename=ckpt_filename,
        every_n_epochs=config.experiment.save_ckpt_freq,
        save_weights_only=True,
        save_last=True,
    )
    tb_logger = TensorBoardLogger(os.path.join(result_path, 'log'), name="iGPT")
    return checkpoint_callback, tb_logger, ImageLogger()
if __name__ == '__main__':
    # Seed python/numpy/torch RNGs for reproducibility across all ranks.
    pl.seed_everything(args.seed)

    # Build iGPT: load the frozen upstream stage-1 model and attach the
    # downstream transformer per the given config.
    model, config = ImageGPT.from_pretrained(args.path_upstream, args.config_downstream)

    # Setup callbacks (checkpointing, TensorBoard, periodic image grids).
    ckpt_callback, logger, logger_img = setup_callbacks(config)

    # Build data modules; batch sizes come from the experiment config.
    dataset = ImageNetDataModule(data_dir=args.imagenet_path,
                                 image_resolution=config.dataset.image_resolution,
                                 train_batch_size=config.experiment.local_batch_size,
                                 valid_batch_size=config.experiment.valid_batch_size,
                                 num_workers=16)
    # Loaders are created manually and passed to fit() rather than letting the
    # Trainer drive the DataModule hooks.
    dataset.setup()
    train_dataloader = dataset.train_dataloader()
    valid_dataloader = dataset.valid_dataloader()
    print(f"len(train_dataset) = {len(dataset.trainset)}")
    print(f"len(valid_dataset) = {len(dataset.validset)}")

    # Calculate how many batches are accumulated so that
    # local_batch_size * n_gpus * grad_accm_steps == total_batch_size exactly.
    assert config.experiment.total_batch_size % (config.experiment.local_batch_size * args.n_gpus) == 0
    grad_accm_steps = config.experiment.total_batch_size // (config.experiment.local_batch_size * args.n_gpus)
    # Total optimizer steps = (effective batches per epoch) * epochs; used by
    # the LR schedule inside the model's configure_optimizers.
    config.optimizer.max_steps = len(dataset.trainset) // config.experiment.total_batch_size * config.experiment.epochs

    # Build trainer: DDP across --n-gpus devices, AMP (fp16) if configured.
    trainer = pl.Trainer(max_epochs=config.experiment.epochs,
                         accumulate_grad_batches=grad_accm_steps,
                         gradient_clip_val=config.optimizer.grad_clip_norm,
                         precision=16 if config.experiment.use_amp else 32,
                         callbacks=[ckpt_callback, logger_img],
                         accelerator="gpu",
                         devices=args.n_gpus,
                         strategy="ddp",
                         logger=logger)
    trainer.fit(model, train_dataloader, valid_dataloader)