# Demonstrator

### Load the model

In [5]:
import rdkit
from rdkit import Chem
import rdkit.rdBase as rkrb
import rdkit.RDLogger as rkl
import os
import torch 
import logging
import numpy as np
from plot_utils import check_metrics
from sample import Sampler
import pandas as pd

# Single seed shared by the sampler and the reference-data subsample so a
# "Restart & Run All" reproduces the same state.
SEED = 1234

# Upper bound on the number of reference rows kept in memory.
MAX_REFERENCE_ROWS = 2_500_000

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Half precision is only used on the GPU; the CPU path stays in float32.
# (bfloat16 would be an alternative where torch.cuda.is_bf16_supported() holds.)
dtype = "float16" if device == "cuda" else "float32"

# Silence RDKit's parser chatter. Kept under its own name so it is not
# confused with the Python logger created below.
rdkit_logger = rkl.logger()
rdkit_logger.setLevel(rkl.ERROR)
rkrb.DisableLog("rdApp.error")

torch.set_num_threads(8)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

sampler = Sampler(
    load_path=os.path.join(
        os.getcwd(), "out", "llama2-M-Full-RSS.pt"
    ),
    device=device,
    seed=SEED,
    dtype=dtype,
    compile=True,
)

# Number of molecules requested per click of the "Generate" button.
num_samples = 100

# Reference data: draw a seeded subsample, capped at the table size so a
# smaller parquet file does not make DataFrame.sample raise.
df_comp = pd.read_parquet(os.path.join(os.getcwd(),"data","OrganiX13.parquet"))
df_comp = df_comp.sample(n=min(MAX_REFERENCE_ROWS, len(df_comp)), random_state=SEED)
comp_context_dict = {c: df_comp[c].to_numpy() for c in ["logp", "sascore", "mol_weight"]} 
comp_smiles = df_comp["smiles"]



INFO:sample:Compiling the model...


In [6]:
from typing import List, Dict
import json
from rdkit.Chem import AllChem

@torch.no_grad()
def convert_to_chemiscope(smiles_list : List[str], context_dict : Dict[str, List[float]]):
    """Write the given molecules and their properties to ``chemiscope_gen.json``.

    Each SMILES is parsed with RDKit and embedded in 3D. Molecules that fail
    to parse or to embed are skipped (and logged), and the property values at
    the same positions are dropped so structures and properties stay aligned.
    The file is written to the current working directory.

    Args:
        smiles_list: SMILES strings, one per molecule.
        context_dict: Maps a property name to one value per entry of
            ``smiles_list`` (same order).

    Raises:
        ValueError: If a property does not have exactly one value per SMILES.
    """
    # For more details on the file format: https://chemiscope.org/docs/tutorial/input-reference.html

    # Fail early on misaligned input instead of silently writing a file whose
    # property columns do not match the number of structures.
    num_inputs = len(smiles_list)
    for name, values in context_dict.items():
        if len(values) != num_inputs:
            raise ValueError(
                f"Context '{name}' has {len(values)} values but {num_inputs} SMILES were given"
            )

    structures = []
    kept_indices = set()  # positions in smiles_list that produced a structure
    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            logging.info(f"Mol invalid: {smi} ! Skipping...")
            continue

        # Fixed seed keeps the embedding reproducible. The call returns the id
        # of the new conformer, which is 0 for the first one on success.
        res = AllChem.EmbedMolecule(mol, randomSeed=0xf00d, maxAttempts=20)
        if res != 0:
            logging.info(f"Could not calculate coordinates for {smi}! Skipping..")
            continue

        positions = mol.GetConformer().GetPositions()
        symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]

        structures.append({
            "size": len(symbols),
            "names": symbols,
            "x": [float(p[0]) for p in positions],
            "y": [float(p[1]) for p in positions],
            "z": [float(p[2]) for p in positions],
        })
        kept_indices.add(i)

    # Keep only the property values whose molecule made it into `structures`.
    properties = {
        name: {
            "target": "structure",
            "values": [v for i, v in enumerate(values) if i in kept_indices],
        }
        for name, values in context_dict.items()
    }

    data = {
        "meta": {
            # // the name of the dataset
            "name": "Test Dataset",
            # // description of the dataset, OPTIONAL
            "description": "This contains data from generated molecules",
            # // authors of the dataset, OPTIONAL
            "authors": ["Niklas Dobberstein, niklas.dobberstein@scai.fraunhofer.de"],
            # // references for the dataset, OPTIONAL
            "references": [
                "",
            ],
        },
        "properties": properties,
        "structures": structures
    }

    out_path = os.path.join(os.getcwd(), "chemiscope_gen.json")
    with open(out_path, "w") as f:
        json.dump(data, f)

    logging.info(f"Wrote file {out_path}")

# Smoke test: export two small molecules with made-up property values.
convert_to_chemiscope(
    smiles_list=["CC=O", "s1ccnc1"],
    context_dict={
        "logp": [1.0, 2.0],
        "sascore": [1.5, -2.0],
    },
)

INFO:root:Wrote file /home/ndobberstein/Projekte/llama2-molgen/chemiscope_gen.json


In [7]:
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import numpy as np
import torch
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
import logging
from plot_utils import calc_context_from_smiles

# Conditioning properties the user can switch on, one checkbox each (all off initially).
context_cols_options = ["logp", "sascore", "mol_weight"]
context_cols_checkboxes = [
    widgets.Checkbox(value=False, description=name)
    for name in context_cols_options
]

# Free-text field for the context SMILES.
context_smi_input = widgets.Text(value="", description="Context SMI:")

# Sampling temperature.
temperature_slider = widgets.FloatSlider(
    value=0.8, min=0, max=2.0, step=0.1, description="Temperature:"
)

# Target value for each conditioning property.
logp_slider = widgets.FloatSlider(
    value=0.0, min=-4, max=7, step=0.5, description="logp:"
)
sascore_slider = widgets.FloatSlider(
    value=2.0, min=1, max=10, step=0.5, description="sascore:"
)
mol_weight_slider = widgets.FloatSlider(
    value=3.0, min=0.5, max=10, step=0.5, description="mol_weight:"
)

# Button that triggers generation.
generate_button = widgets.Button(description="Generate")

# Output areas: one for generation messages, one for the molecule grid.
output = widgets.Output()
molecule_output = widgets.Output()

@torch.no_grad()
def generate_code(_):
    """Button callback: sample molecules with the current widget settings.

    Reads the ticked context checkboxes, the matching slider values, the
    context SMILES and the temperature, asks ``sampler`` for ``num_samples``
    molecules, writes them to ``gen_smiles.txt`` in the working directory and
    hands them to ``display_molecules``.

    Args:
        _: The clicked button (unused).
    """
    # Columns without an entry here fall back to the mol_weight slider.
    slider_by_column = {"logp": logp_slider, "sascore": sascore_slider}

    with output:
        clear_output(wait=False)

        # Columns whose checkbox is ticked, in the order of context_cols_options.
        chosen_cols = [
            name
            for name, box in zip(context_cols_options, context_cols_checkboxes)
            if box.value
        ]

        context_smi = context_smi_input.value.strip()
        temperature = temperature_slider.value

        # No ticked column means no conditioning at all (None, not an empty dict).
        if len(chosen_cols) == 0:
            context_dict = None
        else:
            context_dict = {}
            for name in chosen_cols:
                target = round(slider_by_column.get(name, mol_weight_slider).value, 2)
                # Same target value for every sample in the batch.
                context_dict[name] = target * torch.ones(
                    (num_samples,), device=device, dtype=torch.float
                )

        generation_steps = int(np.ceil(num_samples / 1000))
        smiles, context = sampler.generate(
            context_cols=context_dict,
            context_smi=context_smi,
            start_smiles=None,
            num_samples=num_samples,
            max_new_tokens=256,
            temperature=temperature,
            top_k=25,
            total_gen_steps=generation_steps,
            return_context=True
        )

        # Persist the raw SMILES, one per line.
        smiles_path = os.path.join(os.getcwd(), "gen_smiles.txt")
        with open(smiles_path, "w") as f:
            f.writelines(f"{s}\n" for s in smiles)

        # Render the result as a grid of RDKit drawings.
        display_molecules(smiles, context)



def display_molecules(smiles_list, context_dict):
    """Draw the molecules in a grid inside ``molecule_output``.

    Each drawable molecule gets one subplot. For every property in
    ``context_dict`` the subplot is labelled with the requested value and the
    value recomputed from the SMILES via ``calc_context_from_smiles``. The
    figure is saved as ``gen_mols.png``. When at least two properties are
    present, the drawable molecules and their recomputed properties are also
    exported through ``convert_to_chemiscope``.

    Args:
        smiles_list: SMILES strings to draw. Entries RDKit cannot parse are
            skipped and logged.
        context_dict: Maps a property name to the requested values, indexed by
            position in ``smiles_list``. ``None`` is treated as "no properties".
    """
    if context_dict is None:
        context_dict = {}

    with molecule_output:
        clear_output(wait=False)

        # Parse first and remember each molecule's original position so the
        # requested context values can still be looked up after filtering.
        drawable = []
        for idx, smi in enumerate(smiles_list):
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                logging.info(f"Mol invalid: {smi} ! Skipping...")
                continue
            drawable.append((idx, smi, Draw.MolToImage(mol)))

        # plt.subplots rejects a grid with zero rows, so stop here.
        if not drawable:
            logging.info("No valid molecules to display.")
            return

        num_images = len(drawable)
        num_cols = 5  # Number of columns in the grid
        num_rows = (num_images + num_cols - 1) // num_cols  # ceil division

        # squeeze=False keeps `axes` a 2D array whatever the grid shape.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(25, 25), squeeze=False)
        fig.subplots_adjust(hspace=0.5)

        calculated_context = {c: [] for c in context_dict}
        for slot, ax in enumerate(axes.flat):
            if slot >= num_images:
                fig.delaxes(ax)  # Remove the unused cells of the last row
                continue

            idx, smi, image = drawable[slot]
            ax.imshow(image)
            for j, c in enumerate(context_dict):
                smi_con = round(calc_context_from_smiles([smi], c)[0], 2)
                calculated_context[c].append(smi_con)
                # Stack one label line per property below the image.
                ax.text(0.5, -0.1 * j , f"{c}: {context_dict[c][idx]} vs {smi_con}", transform=ax.transAxes, fontsize=10, ha='center')
            ax.axis('off')

        if len(context_dict) >= 2:
            # Pass only the drawn molecules so SMILES and properties line up.
            convert_to_chemiscope([smi for _, smi, _ in drawable], calculated_context)

        fig.savefig("gen_mols.png")
        plt.show()
        # Release the figure so repeated clicks do not accumulate open figures.
        plt.close(fig)

# Run generate_code whenever the button is clicked.
generate_button.on_click(generate_code)

# Show the controls top to bottom: checkboxes, property sliders, SMILES
# input, temperature, the button, and finally the two output areas.
for control in (
    widgets.HBox(context_cols_checkboxes),
    widgets.HBox((logp_slider, sascore_slider, mol_weight_slider)),
    context_smi_input,
    temperature_slider,
    generate_button,
    output,
    molecule_output,
):
    display(control)

HBox(children=(Checkbox(value=False, description='logp'), Checkbox(value=False, description='sascore'), Checkb…

HBox(children=(FloatSlider(value=0.0, description='logp:', max=7.0, min=-4.0, step=0.5), FloatSlider(value=2.0…

Text(value='', description='Context SMI:')

FloatSlider(value=0.8, description='Temperature:', max=2.0)

Button(description='Generate', style=ButtonStyle())

Output()

Output()

In [None]:
selected_context_cols = ["logp", "sascore", "mol_weight"]

# Uses its own name on purpose: the "Generate" button callback reads the
# global `num_samples` at click time, so rebinding it here would silently
# change the batch size of the interactive demo.
num_random_samples = 25

# (low, high) passed to torch.randint per property; `high` is exclusive and
# the integer draws are scaled by 0.5 afterwards.
randint_bounds = {
    "logp": (-8, 14),
    "sascore": (1, 20),
    "mol_weight": (1, 20),
}
default_bounds = (1, 20)  # any other column is sampled like mol_weight

context_dict = {} if len(selected_context_cols) != 0 else None
for c in selected_context_cols:
    low, high = randint_bounds.get(c, default_bounds)
    v = 0.5 * torch.randint(
        low, high, (num_random_samples,), device=device, dtype=torch.float
    )
    # Sort ascending so each property's targets increase across the grid.
    context_dict[c] = v.sort()[0]

# Generate SMILES using the provided context
smiles, context = sampler.generate(
    context_cols=context_dict,
    context_smi=None,
    start_smiles=None,
    num_samples=num_random_samples,
    max_new_tokens=256,
    temperature=0.8,
    top_k=25,
    total_gen_steps=int(np.ceil(num_random_samples / 1000)),
    return_context=True
)

# Display SMILES as RDKit molecules
display_molecules(smiles, context)


Batch:   0%|          | 0/1 [00:00<?, ?it/s]

Generation:   0%|          | 0/256 [00:00<?, ?it/s]

INFO:sample:Number valid generated: 68.0 %
INFO:sample:---------------
