"""Property calculators, conditioning plots and set metrics for generated SMILES.

The module offers three groups of helpers:

* ``calcContext*`` / ``calc_context_from_smiles`` compute a per-molecule
  property array (LogP, SA score, scaled molecular weight, MMFF energy).
* ``plot_1D_condition`` / ``plot_2D_condition`` / ``plot_3D_condition`` /
  ``plot_unconditional`` compare those computed properties with the values the
  molecules were conditioned on and write figures / metrics to disk.
* ``novelty`` / ``unique_at`` / ``check_metrics`` compute set-level metrics.
"""

import logging
import os
import sys
from typing import Dict, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, RDConfig

# ``sascorer`` lives in RDKit's Contrib directory rather than in an installed
# package, so that directory has to be on ``sys.path`` before the import.
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer  # noqa: E402

logger = logging.getLogger(__name__)

# plt.rcParams.update({'font.size': 13.1})
plt.rcParams.update({"font.size": 12.5})

COL_TO_DISPLAY_NAME = {
    "logp": "LogP",
    "sascore": "SAScore",
    "mol_weight": "Molecular Weight",
}

# Factor applied to the mean MMFF energy (kcal/mol) to express it in Hartree.
_KCAL_PER_MOL_TO_HARTREE = 0.0016


def _mol_from_smiles(smi):
    """Parse *smi* into an RDKit molecule, raising ``ValueError`` if it is invalid.

    ``Chem.MolFromSmiles`` signals a parse failure by returning ``None``; passing
    that on to a descriptor function fails with an error that does not mention
    the offending SMILES, so the failure is surfaced here instead.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smi!r}")
    return mol


def _display_name(con_col):
    """Return the human-readable name of a column, falling back to the raw key.

    ``"energy"`` is a supported property but has no entry in
    ``COL_TO_DISPLAY_NAME``; the fallback keeps plotting from failing on it.
    """
    return COL_TO_DISPLAY_NAME.get(con_col, con_col)


def calcContextSAScore(smiles: List[str]):
    """Return the synthetic-accessibility score of every SMILES as a NumPy array.

    Raises:
        ValueError: if a SMILES string cannot be parsed.
    """
    return np.array([sascorer.calculateScore(_mol_from_smiles(smi)) for smi in smiles])


def calcContextLogP(smiles: List[str]):
    """Return ``Descriptors.MolLogP`` of every SMILES as a NumPy array.

    Raises:
        ValueError: if a SMILES string cannot be parsed.
    """
    return np.array([Descriptors.MolLogP(_mol_from_smiles(smi)) for smi in smiles])


def calcContextEnergy(smiles, num_confs=5, num_threads=48):
    """Return the mean MMFF-optimised conformer energy of every SMILES.

    For each molecule, hydrogens are added, ``num_confs`` conformers are
    embedded and MMFF-optimised, and the mean energy over all conformers is
    multiplied by 0.0016 (kcal/mol -> Hartree).

    Args:
        smiles: iterable of SMILES strings.
        num_confs: number of conformers to embed per molecule.
        num_threads: value forwarded as ``numThreads`` to the RDKit embedding
            and optimisation calls.

    Returns:
        NumPy array with one energy per input SMILES.

    Raises:
        ValueError: if a SMILES string cannot be parsed.
    """
    contexts = []
    for smi in smiles:
        mol = Chem.AddHs(_mol_from_smiles(smi))
        AllChem.EmbedMultipleConfs(mol, num_confs, numThreads=num_threads)
        # One (status, energy) pair per conformer; a non-zero status means the
        # optimisation did not converge.
        results = AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=num_threads)
        energies = []
        for not_converged, energy in results:
            if not_converged != 0:
                print("Not converged!", smi)
            # Non-converged conformers still contribute to the mean.
            energies.append(energy)
        contexts.append(np.mean(energies) * _KCAL_PER_MOL_TO_HARTREE)
    return np.array(contexts)


def calcContextMolWeight(smiles: List[str]):
    """Return ``Descriptors.ExactMolWt / 100`` of every SMILES as a NumPy array.

    NOTE(review): the division by 100 presumably matches the scaling used for
    the conditioning values - confirm against the training pipeline.

    Raises:
        ValueError: if a SMILES string cannot be parsed.
    """
    return np.array(
        [Descriptors.ExactMolWt(_mol_from_smiles(smi)) / 100 for smi in smiles]
    )


# Maps a context column name to the function that computes it from SMILES.
_CONTEXT_CALCULATORS = {
    "mol_weight": calcContextMolWeight,
    "logp": calcContextLogP,
    "sascore": calcContextSAScore,
    # TODO: Change to something better
    "energy": calcContextEnergy,
}


def calc_context_from_smiles(generated_smiles, con_col):
    """Compute the property named *con_col* for every generated SMILES.

    Args:
        generated_smiles: iterable of SMILES strings.
        con_col: one of ``"mol_weight"``, ``"logp"``, ``"sascore"``, ``"energy"``.

    Returns:
        NumPy array with one value per SMILES.

    Raises:
        ValueError: if *con_col* is not a supported column, or a SMILES string
            cannot be parsed.
    """
    try:
        calculator = _CONTEXT_CALCULATORS[con_col]
    except KeyError:
        raise ValueError(
            f"Unknown context column {con_col!r}; "
            f"expected one of {sorted(_CONTEXT_CALCULATORS)}"
        ) from None
    return calculator(generated_smiles)


def _report_metrics(save_path, context_col, sq_errors, abs_errors):
    """Log per-column MSE / MAD and write them to ``metrics.txt`` in *save_path*.

    ``sq_errors`` and ``abs_errors`` hold one list of per-sample errors for each
    of the leading entries of ``context_col``.
    """
    lines = []
    for col, sq, ab in zip(context_col, sq_errors, abs_errors):
        lines.append(f"MSE {col}: {np.mean(sq)}")
        lines.append(f"MAD {col}: {np.mean(ab)}")
    for line in lines:
        logger.info(line)
    with open(os.path.join(save_path, "metrics.txt"), "w") as f:
        f.writelines(f"{line} \n" for line in lines)


def plot_1D_condition(
    context_col,
    save_path,
    new_context,
    generated_smiles,
    temperature,
    context_dict,
    context_scaler=None,
):
    """Plot target vs. computed property for each single-property condition.

    For every column in ``context_col`` this creates the directory
    ``<save_path>/<col>_<all cols joined by '-'>_temp<temperature>`` and writes
    ``predictions.csv`` and ``graph.png`` into it. The figure shows the computed
    values against the targets, the per-target mean, the ``y=x`` line and a
    histogram of ``context_dict[col]`` on a twin axis. MSE, MAD and the standard
    deviation of the absolute errors are logged.

    Each column gets its own directory directly below ``save_path``; the
    directory of one column is never nested inside that of a previous column.

    Args:
        context_col: list of column names to evaluate.
        save_path: base output directory.
        new_context: mapping from column name to an object exposing
            ``.cpu().detach().numpy()`` (presumably a torch tensor) that holds
            the target value of every generated molecule.
        generated_smiles: generated SMILES strings, aligned with ``new_context``.
        temperature: sampling temperature; only used in paths and the title.
        context_dict: mapping from column name to the dataset values whose
            distribution is drawn as a histogram.
        context_scaler: must be ``None``; scaling is not implemented.

    Raises:
        NotImplementedError: if ``context_scaler`` is given.
        ValueError: for an unsupported column or an unparsable SMILES.
    """
    if context_scaler is not None:
        raise NotImplementedError("Not implemented yet")

    for con_col in context_col:
        col_dir = os.path.join(
            save_path, f"{con_col}_{'-'.join(context_col)}_temp{temperature}"
        )
        os.makedirs(col_dir, exist_ok=True)

        current_context = new_context[con_col].cpu().detach().numpy().reshape(-1)
        predicted_context = calc_context_from_smiles(generated_smiles, con_col)

        labels = np.unique(current_context)
        mean_vals_pred = []
        sq_errors = []
        abs_errors = []
        for label in labels:
            selected = predicted_context[current_context == label]
            mean_vals_pred.append(np.mean(selected))
            sq_errors.extend((selected - label) ** 2)
            abs_errors.extend(np.abs(selected - label))

        mse = np.mean(sq_errors)
        mad = np.mean(abs_errors)
        logger.info(f"MSE {mse}")
        logger.info(f"MAD {mad}")
        logger.info(f"SD: {np.std(abs_errors)}")

        fig, ax1 = plt.subplots()
        try:
            # Scatter plot
            ax1.scatter(
                current_context,
                predicted_context,
                label="Ground Truth vs Prediction",
                c="blue",
                alpha=0.5,
            )
            diagonal = np.arange(np.min(current_context), np.max(current_context) + 1)
            ax1.plot(diagonal, diagonal, label="y=x", c="black")
            ax1.scatter(labels, mean_vals_pred, label="Mean predicted values", c="red")
            ax1.set_xlabel("Ground Truth")
            ax1.set_ylabel("Prediction")

            # Histogram on a twin axis sharing the x-axis
            ax2 = ax1.twinx()
            sns.histplot(
                context_dict[con_col],
                label="Dataset distribution",
                alpha=0.5,
                ax=ax2,
            )
            ax2.set_ylabel("Frequency")

            # Combine the legends of both axes into one
            handles1, labels1 = ax1.get_legend_handles_labels()
            handles2, labels2 = ax2.get_legend_handles_labels()
            ax1.legend(handles1 + handles2, labels1 + labels2)

            plt.xlim((np.min(current_context), np.max(current_context) + 1))
            plt.title(
                f"{_display_name(con_col)} - temperature: {temperature} - mse: {round(mse, 4)}"
            )

            out_df = pd.DataFrame(
                {
                    "smiles": generated_smiles,
                    f"{con_col}": predicted_context.tolist(),
                    f"target_{con_col}": current_context.tolist(),
                }
            )
            out_df.to_csv(os.path.join(col_dir, "predictions.csv"), index=False)

            out_path = os.path.join(col_dir, "graph.png")
            print(f"Saved to {out_path}")
            plt.savefig(out_path)
        finally:
            # Close rather than clear so figures do not pile up across calls.
            plt.close(fig)


def plot_2D_condition(
    context_col,
    save_path,
    new_context,
    generated_smiles,
    temperature,
    label: Union[str, None] = None,
):
    """Scatter computed properties for every pair of target values (2 properties).

    Output goes to ``<save_path>/multicond2_<cols joined by '-'>_temp=<temperature>``
    (plus a ``label`` sub-directory when given): ``metrics.txt`` with the MSE and
    MAD of both properties and ``graph.png``. Each combination of unique target
    values gets a random colour; its target is drawn as a triangle and the
    computed values of the molecules conditioned on it as dots.

    Args:
        context_col: column names; the first two are plotted as x and y.
        save_path: base output directory.
        new_context: mapping from column name to an object exposing
            ``.cpu().numpy()`` (presumably a torch tensor) with the targets.
        generated_smiles: generated SMILES strings, aligned with ``new_context``.
        temperature: sampling temperature; only used in the directory name.
        label: optional sub-directory name.

    Returns:
        The directory the outputs were written to.

    Raises:
        ValueError: for an unsupported column or an unparsable SMILES.
    """
    save_path = os.path.join(
        save_path, f"multicond2_{'-'.join(context_col)}_temp={temperature}"
    )
    if label is not None:
        save_path = os.path.join(save_path, label)
    os.makedirs(save_path, exist_ok=True)

    col_x, col_y = context_col[0], context_col[1]
    pred_x = np.array(calc_context_from_smiles(generated_smiles, col_x))
    pred_y = np.array(calc_context_from_smiles(generated_smiles, col_y))
    real_x = new_context[col_x].cpu().numpy()
    real_y = new_context[col_y].cpu().numpy()
    unique_x = np.unique(real_x)
    unique_y = np.unique(real_y)

    sq_errors = ([], [])
    abs_errors = ([], [])

    fig = plt.figure()
    try:
        ax = plt.subplot(111)
        for v1 in unique_x:
            for v2 in unique_y:
                indices = np.nonzero((real_x == v1) & (real_y == v2))[0]
                color = np.random.rand(3)

                x_pred = pred_x[indices].ravel()
                y_pred = pred_y[indices].ravel()
                sq_errors[0].extend((x_pred - v1) ** 2)
                abs_errors[0].extend(np.abs(x_pred - v1))
                sq_errors[1].extend((y_pred - v2) ** 2)
                abs_errors[1].extend(np.abs(y_pred - v2))

                ax.scatter(x_pred, y_pred, color=color, alpha=0.5)
                ax.scatter(
                    v1, v2, color=color, label=f"({v1}, {v2})", marker="^", s=20.0
                )

        _report_metrics(save_path, context_col, sq_errors, abs_errors)

        ax.set_xlabel(_display_name(col_x))
        ax.set_ylabel(_display_name(col_y))

        # Shrink the axes so the legend fits to the right of them.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        ax.set_title("Multi Property Distribution of Generated Molecules")

        out_path = os.path.join(save_path, "graph.png")
        logger.info(f"Saved to {out_path}")
        plt.savefig(out_path)
    finally:
        plt.close(fig)
    return save_path


def plot_3D_condition(
    context_col, save_path, new_context, generated_smiles, temperature
):
    """Scatter computed properties for every triple of target values (3 properties).

    Output goes to ``<save_path>/multicond3_<cols joined by '-'>_temp=<temperature>``:
    ``metrics.txt`` with the MSE and MAD of the three properties and a 3D
    ``graph.png``. Each combination of unique target values gets a random
    colour, shared by the target point and the computed values of the molecules
    conditioned on it.

    Args:
        context_col: column names; the first three are plotted as x, y and z.
        save_path: base output directory.
        new_context: mapping from column name to an object exposing
            ``.cpu().numpy()`` (presumably a torch tensor) with the targets.
        generated_smiles: generated SMILES strings, aligned with ``new_context``.
        temperature: sampling temperature; only used in the directory name.

    Returns:
        The directory the outputs were written to.

    Raises:
        ValueError: for an unsupported column or an unparsable SMILES.
    """
    save_path = os.path.join(
        save_path, f"multicond3_{'-'.join(context_col)}_temp={temperature}"
    )
    os.makedirs(save_path, exist_ok=True)

    predicted = {
        con_col: np.array(calc_context_from_smiles(generated_smiles, con_col))
        for con_col in context_col
    }
    col_x, col_y, col_z = context_col[0], context_col[1], context_col[2]
    real_x = new_context[col_x].cpu().numpy()
    real_y = new_context[col_y].cpu().numpy()
    real_z = new_context[col_z].cpu().numpy()
    unique_x = np.unique(real_x)
    unique_y = np.unique(real_y)
    unique_z = np.unique(real_z)

    sq_errors = ([], [], [])
    abs_errors = ([], [], [])

    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection="3d")
        for v1 in unique_x:
            for v2 in unique_y:
                for v3 in unique_z:
                    mask = (real_x == v1) & (real_y == v2) & (real_z == v3)
                    indices = np.nonzero(mask)[0]
                    color = np.random.rand(3)

                    x_pred = predicted[col_x][indices].ravel()
                    y_pred = predicted[col_y][indices].ravel()
                    z_pred = predicted[col_z][indices].ravel()
                    for axis, (pred, target) in enumerate(
                        ((x_pred, v1), (y_pred, v2), (z_pred, v3))
                    ):
                        sq_errors[axis].extend((pred - target) ** 2)
                        abs_errors[axis].extend(np.abs(pred - target))

                    ax.scatter(
                        v1, v2, v3, color=color, label=f"({v1}, {v2}, {v3})", s=20.0
                    )
                    ax.scatter(x_pred, y_pred, z_pred, color=color)

        _report_metrics(save_path, context_col, sq_errors, abs_errors)

        ax.set_xlabel(_display_name(col_x))
        ax.set_ylabel(_display_name(col_y))
        ax.set_zlabel(_display_name(col_z))

        plt.legend(
            bbox_to_anchor=(1.035, 0.5),
            loc="center right",
            bbox_transform=plt.gcf().transFigure,
        )
        plt.subplots_adjust(left=0.05, bottom=0.1, right=0.775)
        plt.title("Multi Property Distribution of Generated Molecules")

        out_path = os.path.join(save_path, "graph.png")
        print(f"Saved to {out_path}")
        plt.savefig(out_path)
    finally:
        plt.close(fig)
    return save_path


def plot_unconditional(
    out_path: str = os.getcwd(),
    smiles: Optional[List[str]] = None,
    temperature: float = 0.8,
    cmp_context_dict: Union[Dict[str, np.ndarray], None] = None,
    context_cols: Optional[List[str]] = None,
):
    """Plot the property distributions of unconditionally generated molecules.

    For every column one density histogram of the computed property is written
    to ``<out_path>/unconditional/unc_<col>_temp=<temperature>.png``; when
    ``cmp_context_dict`` is given, the dataset distribution is drawn underneath.

    Args:
        out_path: base output directory. Note that the default is the working
            directory at import time, because it is evaluated when the function
            is defined.
        smiles: generated SMILES strings; ``None`` is treated as an empty list.
        temperature: sampling temperature; only used in the title / file name.
        cmp_context_dict: optional mapping from column name to dataset values.
        context_cols: columns to plot; ``None`` selects ``logp``, ``sascore``
            and ``mol_weight``.

    Raises:
        ValueError: for an unsupported column or an unparsable SMILES.
    """
    if smiles is None:
        smiles = []
    if context_cols is None:
        context_cols = ["logp", "sascore", "mol_weight"]

    out_path = os.path.join(out_path, "unconditional")
    os.makedirs(out_path, exist_ok=True)

    for c in context_cols:
        context_cal = calc_context_from_smiles(smiles, c)
        # A fresh figure per column, closed afterwards, so that neither state
        # left behind by other plots nor open figures accumulate.
        fig = plt.figure()
        try:
            if cmp_context_dict is not None:
                sns.histplot(
                    cmp_context_dict[c],
                    stat="density",
                    label="Dataset Distribution",
                    alpha=0.75,
                    color="blue",
                )
            sns.histplot(
                context_cal,
                stat="density",
                label="Generated Molecules Distribution",
                alpha=0.5,
                color="orange",
            )

            if c == "logp":
                plt.xlim((-6, 8))
            else:
                plt.xlim((0, 10))

            plt.xlabel(_display_name(c))
            plt.title(
                f"Unconditional Distribution {_display_name(c)} \nwith Temperature {temperature}"
            )
            plt.legend()

            out_file = os.path.join(out_path, f"unc_{c}_temp={temperature}.png")
            plt.savefig(out_file)
            logger.info(f"Saved Unconditional to {out_file}")
        finally:
            plt.close(fig)


def novelty(gen, train):
    """Return the fraction of distinct generated SMILES that are not in *train*.

    ``None`` entries in *gen* are ignored. Returns 0.0 when *gen* contains no
    non-``None`` entry (instead of dividing by zero).
    """
    gen_smiles_set = set(gen) - {None}
    if not gen_smiles_set:
        return 0.0
    return len(gen_smiles_set - set(train)) / len(gen_smiles_set)


def unique_at(gen, k=1000):
    """Return the fraction of distinct entries among the first *k* of *gen*.

    Returns 0.0 when that prefix is empty (instead of dividing by zero).
    """
    head = gen[:k]
    if not head:
        return 0.0
    return len(set(head)) / len(head)


def check_metrics(generated_smiles: List[str], dataset_smiles: List[str]):
    """Compute novelty, uniqueness@1k/10k and validity of generated SMILES.

    ``None`` entries in ``generated_smiles`` count as invalid molecules: they
    lower ``validity`` and are excluded from the other metrics.

    Returns:
        dict with the keys ``novelty``, ``unique_at_1k``, ``unique_at_10k`` and
        ``validity``; every value is 0.0 when ``generated_smiles`` is empty.
    """
    len_before = len(generated_smiles)
    valid_smiles = [g for g in generated_smiles if g is not None]

    return dict(
        novelty=novelty(valid_smiles, dataset_smiles),
        unique_at_1k=unique_at(valid_smiles, k=1000),
        unique_at_10k=unique_at(valid_smiles, k=10000),
        validity=len(valid_smiles) / float(len_before) if len_before else 0.0,
    )