from trainer import (
    IOConfig,
    LoaderConfig,
    Trainer,
    TrainerArgs,
    ModelArgs,
    ContextArgs,
    OptimizerConfig,
)
from torch.distributed.elastic.multiprocessing.errors import record

import hydra
from omegaconf import DictConfig, OmegaConf
import logging
import sys
import os
import torch


def setup_logger(run_name: str, log_path: str):
    """Configure root logging to stdout and a per-run file.

    Under torchrun/DDP (detected via the RANK env var), log lines are tagged
    with [rank, local_rank, world_size] so interleaved worker output stays
    readable.
    """
    ddp = int(os.environ.get("RANK", -1)) != -1
    if ddp:
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])

        formatter = logging.Formatter(
            f"[%(levelname)s] DDP[{ddp_rank},{ddp_local_rank},{ddp_world_size}] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    else:
        formatter = logging.Formatter(
            "[%(levelname)s] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)

    os.makedirs(log_path, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(log_path, f"train_{run_name}.log"))
    file_handler.setFormatter(formatter)

    logging.basicConfig(level=logging.INFO, handlers=[stream_handler, file_handler])

    return logging.getLogger()

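
# With the DDP formatter above, each worker emits lines shaped like the
# following (rank/time/location values are illustrative, not real output):
#   [INFO] DDP[1,1,4] 2025-01-01 12:00:00 - [trainer.py:42]: Using config
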
@record
@hydra.main(version_base=None, config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    logger = setup_logger(
        cfg.get("run_name", "default"), cfg.get("io", {"out_dir": "out"})["out_dir"]
    )

    logger.info("Using config")
    logger.info(OmegaConf.to_yaml(cfg))

    # All remaining options live under the `train` key of the Hydra config.
    cfg = cfg["train"]
    io_conf = IOConfig(**cfg.get("io", {}))
    loader_conf = LoaderConfig(**cfg.get("loader", {}))
    model_args = ModelArgs(**cfg.get("model", {}))
    ctx_args = ContextArgs(**cfg.get("context", {}))
    optimizer_conf = OptimizerConfig(**cfg.get("optimizer", {}))
    train_args = TrainerArgs(
        io_conf=io_conf,
        loader_conf=loader_conf,
        model_conf=model_args,
        context_conf=ctx_args,
        optimizer_conf=optimizer_conf,
        run_name=cfg.get("label", "train_run"),
    )

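    # For reference, the keys read in this function imply a config shaped
    # roughly like the sketch below. Note that `run_name` and `io.out_dir`
    # are read from the top level (for logging) before cfg is rebound to
    # cfg["train"]. Block contents are placeholders; the real fields are
    # whatever IOConfig, LoaderConfig, ModelArgs, ContextArgs and
    # OptimizerConfig accept:
    #
    #   run_name: default
    #   io:
    #     out_dir: out
    #   train:
    #     label: train_run
    #     dtype: bfloat16
    #     compile: false
    #     profile: false
    #     io: {...}
    #     loader: {...}
    #     model: {...}
    #     context: {...}
    #     optimizer: {...}
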
    # Cap intra-op CPU threads; adjust to the cores available per process.
    torch.set_num_threads(8)

    trainer = Trainer(
        train_args=train_args,
        dtype=cfg.get("dtype", "float16"),
        compile=cfg.get("compile", False),
    )
    should_profile = cfg.get("profile", False)

    if should_profile:
        # Profile the full training run on CPU and CUDA, then print per-op
        # timings sorted by self CUDA time.
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            trainer.train()

        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
    else:
        trainer.train()


if __name__ == "__main__":
    main()
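
# Example launches, assuming this file is saved as train.py (a hypothetical
# name). torchrun supplies RANK/LOCAL_RANK/WORLD_SIZE, and @record surfaces
# per-rank tracebacks on failure:
#   torchrun --standalone --nproc_per_node=4 train.py train.compile=true
# Single-process run with the default config:
#   python train.py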