Paul Engstler
Initial commit
92f0e98
import matplotlib
matplotlib.use('Agg')
import os
import wandb
import pytorch_lightning as pl
from data import VerseDataModule
from model import VerseFxClassifier
from utils.config import get_config
def main(config_path="config.yaml"):
    """Train a VerseFxClassifier using the settings in ``config_path``.

    The launcher-only keys ``USE_WANDB``, ``WANDB_API_KEY`` and ``SAVE_MODEL``
    are popped from the config before it is handed to ``wandb.init``, so they
    (in particular the API key) are not recorded as run hyperparameters.

    Args:
        config_path: Path passed to ``get_config``; defaults to the original
            hard-coded ``"config.yaml"``.
    """
    config = get_config(config_path)

    use_wandb = bool(config.pop('USE_WANDB', False))
    wandb_mode = 'online' if use_wandb else 'disabled'
    # Pop (rather than read) so the secret never reaches the logged config;
    # defaults keep a config without these keys from crashing with KeyError.
    wandb_api_key = config.pop('WANDB_API_KEY', None)
    save_model = bool(config.pop('SAVE_MODEL', False))

    # Only authenticate when actually logging online: in 'disabled' mode a
    # login is unnecessary and may prompt or fail if no key is configured.
    if use_wandb:
        wandb.login(key=wandb_api_key)

    run = wandb.init(
        project=f'fx-{config["task"]}-baseline-3d',
        entity='ifl-diva',
        config=config,
        mode=wandb_mode,
    )
    # Hyperparameters are read back from wandb.config rather than the raw
    # dict — presumably so values injected by a wandb sweep take precedence.
    hparams = wandb.config
    # Expected to attach to the run initialised above rather than start a new one.
    wandb_logger = pl.loggers.WandbLogger()

    model = VerseFxClassifier(hparams)
    data = VerseDataModule(hparams)

    callbacks = [
        pl.callbacks.EarlyStopping(
            monitor="val/F1",
            mode="max",
            patience=hparams.early_stopping_patience,
        ),
    ]
    if save_model:
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                monitor='val/F1',
                mode="max",
                dirpath='saved_models',
                # Doubled braces survive the f-string so Lightning fills in the
                # epoch and metric value; auto_insert_metric_name=False keeps the
                # '/' in 'val/F1' out of the file name (it would create a folder).
                filename=f"{run.name}-epoch{{epoch}}-val_F1={{val/F1:.3f}}",
                auto_insert_metric_name=False,
            )
        )

    trainer = pl.Trainer(
        gpus=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        # Optional debugging / tuning toggles:
        # max_epochs=2,
        callbacks=callbacks,
        # auto_lr_find=hparams.auto_lr_find,
    )

    # Exiting the run's context manager finishes the wandb run, including
    # when fit() raises.
    with run:
        trainer.fit(model, data)


if __name__ == '__main__':
    main()