io:
  # I/O, logging, and checkpointing
  out_dir : "debug"  # output directory for checkpoints and logs
  eval_interval : 10  # run evaluation every N training iterations
  log_interval : 10  # log training metrics every N iterations
  eval_iters : 5  # number of batches used per evaluation
  eval_only : false  # if True, script exits right after the first eval
  always_save_checkpoint : true  # if True, always save a checkpoint after each eval
  init_from : "scratch"  # 'scratch' or 'resume'
  resume_when_snapshot_available: false
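  # Note (assumed behaviour, not verified against the training script): with
  # eval_interval of 10 and max_iters of 20, evaluation (and, because
  # always_save_checkpoint is true, checkpointing) should run roughly twice
  # per training run. To continue a previous run, the block would presumably
  # be switched to something like:
  #   init_from : "resume"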

loader:
  batch_size : 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
  max_seq_len : 768  # maximum sequence length in tokens
  dataset : "smiles"
  processed_dataset_ckpt : "processed_dataset_500000.pkl"  # preprocessed dataset pickle to load
  fragment_creator : "rss"
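  # Rough throughput arithmetic (assuming a nanoGPT-style loop in which these
  # values simply multiply): one optimizer step covers
  #   batch_size * gradient_accumulation_steps = 4 * 4 = 16 sequences,
  #   i.e. up to 4 * 4 * 768 = 12,288 tokens per update.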

model:
  dim : 32
  n_layers : 1
  n_heads : 1
  multiple_of : 16
  dropout : 0.1
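  # Assumed semantics (llama2.c-style SwiGLU feed-forward, not confirmed
  # here): multiple_of rounds the hidden size of the MLP up to a multiple of
  # 16, roughly hidden = multiple_of * ceil((2/3 * 4 * dim) / multiple_of),
  # which for dim of 32 would give 96.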

context:
  context_keys: ["logp", "sascore", "mol_weight"]
  context_dims : [1,1,1]
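  # The conditioning keys are standard molecular descriptors: logp is the
  # octanol-water partition coefficient, sascore the synthetic accessibility
  # score, and mol_weight the molecular weight. context_dims gives the width
  # of each conditioning input (one scalar per property here); how these are
  # embedded into the model is decided by the training code, not this file.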

optimizer:
  gradient_accumulation_steps : 4  # used to simulate larger batch sizes
  learning_rate : 1e-4  # max learning rate
  max_iters : 20  # total number of training iterations
  weight_decay : 1e-1
  beta1 : 0.9
  beta2 : 0.95
  grad_clip : 1.0  # clip gradients at this value, or disable if == 0.0
  # learning rate decay settings
  decay_lr : true  # whether to decay the learning rate
  warmup_iters : 10  # how many steps to warm up for
  lr_decay_iters : 100  # should be ~= max_iters per Chinchilla
  min_lr : 0.0  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
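  # Schedule sketch (assuming the usual warmup-plus-cosine decay that the
  # Chinchilla comments above suggest): the learning rate ramps linearly to
  # 1e-4 over the first 10 iterations, then decays along a cosine toward
  # min_lr (0.0) at iteration 100. Since max_iters is only 20, this debug run
  # stops well before the decay finishes.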

dtype : "float16"  # training dtype; "float32" and "bfloat16" are also supported
compile : false  # whether to use torch.compile; in testing it was noticeably slow
label : "llama2-Debug"
profile : false  # enable profiling of the training loop
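
# Hypothetical usage sketch (kept as a comment so the file stays valid YAML;
# every name below is an assumption, not the project's actual loader). It only
# shows how a PyYAML-based trainer could read this file and map dtype onto a
# torch dtype:
#
#   import yaml
#   import torch
#
#   with open("debug.yaml") as f:  # hypothetical filename for this config
#       cfg = yaml.safe_load(f)
#
#   torch_dtype = {"float16": torch.float16,
#                  "bfloat16": torch.bfloat16,
#                  "float32": torch.float32}[cfg["dtype"]]
#   model_cfg = cfg["model"]  # dim, n_layers, n_heads, multiple_of, dropout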