io: # I/O
  out_dir : "out"
  eval_interval : 500
  log_interval : 10
  eval_iters : 10
  eval_only : false # if True, script exits right after the first eval
  always_save_checkpoint : false # if True, always save a checkpoint after each eval
  init_from : "scratch" # 'scratch' or 'resume'
  resume_when_snapshot_available: true

loader:
  batch_size : 384 # if gradient_accumulation_steps > 1, this is the micro-batch size
  max_seq_len : 768
  dataset : "smiles"
  processed_dataset_ckpt : "processed_dataset_None.pkl"
  fragment_creator : "bricks"

model:
  dim : 256
  n_layers : 8
  n_heads : 8
  multiple_of : 128
  dropout : 0.1

context:
  context_keys : ["logp", "sascore", "mol_weight"]
  context_dims : [1, 1, 1]

optimizer:
  gradient_accumulation_steps : 4 # used to simulate larger batch sizes
  learning_rate : 1e-4 # max learning rate
  max_iters : 100000 # total number of training iterations
  weight_decay : 1e-1
  beta1 : 0.9
  beta2 : 0.95
  grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
  # learning rate decay settings
  decay_lr : true # whether to decay the learning rate
  warmup_iters : 1000 # how many steps to warm up for
  lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
  min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla

dtype : "float16" # use float16 for training; could also be changed to float32 or bfloat16
compile : false # use torch.compile, but in my tests this is really slow

label : "llama2-M-Full-BRICKS"
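
# Effective batch size (descriptive note): assuming the training loop accumulates
# gradients over gradient_accumulation_steps micro-batches before each optimizer
# step, as the comments above suggest, one parameter update sees
#   batch_size * gradient_accumulation_steps = 384 * 4 = 1536 sequences,
# i.e. at most 1536 * max_seq_len = 1536 * 768 = 1,179,648 tokens per step.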