|
import argparse |
|
import contextlib |
|
import importlib |
|
import logging |
|
import os |
|
import sys |
|
import time |
|
import traceback |
|
import pytorch_lightning as pl |
|
import torch |
|
from pytorch_lightning import Trainer |
|
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint |
|
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger |
|
from pytorch_lightning.utilities.rank_zero import rank_zero_only |
|
import craftsman |
|
from craftsman.systems.base import BaseSystem |
|
from craftsman.utils.callbacks import ( |
|
CodeSnapshotCallback, |
|
ConfigSnapshotCallback, |
|
CustomProgressBar, |
|
ProgressCallback, |
|
) |
|
from craftsman.utils.config import ExperimentConfig, load_config |
|
from craftsman.utils.misc import get_rank |
|
from craftsman.utils.typing import Optional |
|
class ColoredFilter(logging.Filter):
    """A logging filter that wraps log records in ANSI color codes.

    The level name is colorized and bracketed (e.g. green ``[INFO]``) and a
    reset code is appended to the message so the color does not bleed into
    subsequent terminal output.

    Note: the record is mutated in place, so the colorized level name and
    message are also seen by any other handlers attached to the same logger.
    """

    # ANSI escape sequences for terminal colors.
    # (Fix: RESET was previously defined twice — "\033[0m" and "\x1b[0m" are
    # the same byte sequence; the redundant second definition is removed.)
    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    # Mapping from log level name to the color used for that level.
    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def __init__(self):
        super().__init__()

    def filter(self, record):
        """Colorize ``record`` in place; always admit the record."""
        if record.levelname in self.COLORS:
            color_start = self.COLORS[record.levelname]
            record.levelname = f"{color_start}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True
|
|
|
|
|
def load_custom_module(module_path):
    """Import a single custom module from a file or a package directory.

    ``module_path`` may be either a ``.py`` file or a directory containing an
    ``__init__.py``. The imported module is registered in ``sys.modules``
    under its base name (without extension for files).

    Returns:
        True if the module was imported successfully, False otherwise (the
        traceback is printed instead of raised, so one broken custom module
        cannot abort the whole launch).
    """
    module_name = os.path.basename(module_path)
    if os.path.isfile(module_path):
        # Fix: strip the extension from the *base name*. The original split
        # the full path, producing sys.modules keys containing path
        # separators (e.g. "custom/foo"), which are not valid module names.
        module_name = os.path.splitext(module_name)[0]
    try:
        if os.path.isfile(module_path):
            module_spec = importlib.util.spec_from_file_location(
                module_name, module_path
            )
        else:
            # Directory: treat it as a package rooted at its __init__.py.
            module_spec = importlib.util.spec_from_file_location(
                module_name, os.path.join(module_path, "__init__.py")
            )

        module = importlib.util.module_from_spec(module_spec)
        # Register before exec so the module can self-reference on import.
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)
        return True
    except Exception as e:
        print(traceback.format_exc())
        print(f"Cannot import {module_path} module for custom nodes:", e)
        return False
|
|
|
|
|
def load_custom_modules():
    """Import every custom extension module found under the ``custom`` dir.

    Each ``.py`` file or package directory inside ``custom`` is imported via
    ``load_custom_module`` (entries ending in ``.disabled`` and the
    ``__pycache__`` directory are skipped), then a per-module import-time
    report is printed, flagging failed imports.
    """
    node_paths = ["custom"]
    node_import_times = []
    for custom_node_path in node_paths:
        # Fix: the original checked os.path.exists("node_paths") — a literal
        # string, not the paths listed above — so custom modules were only
        # ever loaded if a directory literally named "node_paths" existed.
        if not os.path.exists(custom_node_path):
            continue
        possible_modules = os.listdir(custom_node_path)
        if "__pycache__" in possible_modules:
            possible_modules.remove("__pycache__")

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            # Skip non-Python files; directories are assumed to be packages.
            if (
                os.path.isfile(module_path)
                and os.path.splitext(module_path)[1] != ".py"
            ):
                continue
            if module_path.endswith(".disabled"):
                continue
            time_before = time.perf_counter()
            success = load_custom_module(module_path)
            node_import_times.append(
                (time.perf_counter() - time_before, module_path, success)
            )

    if len(node_import_times) > 0:
        print("\nImport times for custom modules:")
        for n in sorted(node_import_times):
            if n[2]:
                import_message = ""
            else:
                import_message = " (IMPORT FAILED)"
            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
        print()
|
|
|
|
|
def main(args, extras) -> None:
    """Launch entry point.

    Configures GPU visibility and logging, loads the experiment config,
    instantiates the datamodule / system / Trainer, then dispatches to
    fit / validate / test / predict according to the CLI mode flags.

    Args:
        args: parsed CLI arguments (see the argparse setup in ``__main__``).
        extras: unrecognized CLI arguments, forwarded to ``load_config`` as
            config overrides.
    """
    # Make GPU indices follow PCI bus order so they match `nvidia-smi`.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
    selected_gpus = [0]
    # Trade a little float32 matmul precision for speed on tensor cores.
    torch.set_float32_matmul_precision("high")

    # Lightning always uses all *visible* devices; which GPUs are visible is
    # controlled through CUDA_VISIBLE_DEVICES below.
    devices = -1
    if len(env_gpus) > 0:
        # CUDA_VISIBLE_DEVICES was set externally (e.g. by SLURM); it takes
        # precedence and --gpu is ignored, as documented in its help text.
        n_gpus = len(env_gpus)
    else:
        selected_gpus = list(args.gpu.split(","))
        n_gpus = len(selected_gpus)
        print(f"Using {n_gpus} GPUs: {selected_gpus}")
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.typecheck:
        from jaxtyping import install_import_hook

        # Enable runtime type checking on everything under the craftsman
        # package; must run before craftsman submodules are (re)imported.
        install_import_hook("craftsman", "typeguard.typechecked")

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        # Fix: not every handler has a `stream` attribute (e.g. custom or
        # file-based handlers attached elsewhere); guard with getattr and
        # compare by identity instead of dereferencing unconditionally.
        if getattr(handler, "stream", None) is sys.stderr:
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                # Plain bracketed format — no ANSI colors in gradio mode.
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    load_custom_modules()

    # Load the experiment config, applying CLI overrides from `extras`.
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # Offset the seed by rank so each DDP process gets distinct randomness.
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = craftsman.find(cfg.data_type)(cfg.data)
    system: BaseSystem = craftsman.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        # In gradio mode, also mirror log records to a file inside trial_dir.
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.INFO)
        if args.verbose:
            fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            # Snapshot the code and the resolved config for reproducibility.
            CodeSnapshotCallback(
                os.path.join(cfg.trial_dir, "code"), use_version=False
            ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
        ]
    if args.gradio:
        # Progress is written to a file so the gradio frontend can poll it.
        callbacks += [
            ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
        ]
    else:
        callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        # Helper: dump `lines` to `file`, one entry per line.
        with open(file, "w") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    if args.train:
        # Create the tensorboard directory; rank_zero_only ensures only the
        # main DDP process touches the filesystem.
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        loggers += [
            TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            CSVLogger(cfg.trial_dir, name="csv_logs"),
        ] + system.get_loggers()
        # Record the exact launch command for reproducibility.
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    trainer = Trainer(
        callbacks=callbacks,
        logger=loggers,
        # NOTE(review): inference_mode=False keeps grads available during
        # predict/validate — presumably some systems need them; confirm.
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        **cfg.trainer,
    )

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        # Restore epoch/global_step bookkeeping when resuming without going
        # through Trainer's own fit-resume path (validate/test/export).
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # Also run prediction so the gradio frontend has artifacts.
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
|
|
|
|
|
if __name__ == "__main__":
    # Build the CLI: one required config path, exactly one required mode
    # flag, plus a handful of optional switches.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    # The run modes are mutually exclusive and one must be chosen.
    mode_group = parser.add_mutually_exclusive_group(required=True)
    for mode_flag in ("--train", "--validate", "--test", "--export"):
        mode_group.add_argument(mode_flag, action="store_true")

    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )

    known_args, unknown_args = parser.parse_known_args()

    if known_args.gradio:
        # Gradio captures stderr; route stdout there so prints show up too.
        with contextlib.redirect_stdout(sys.stderr):
            main(known_args, unknown_args)
    else:
        main(known_args, unknown_args)
|