import gradio as gr
import numpy as np
import os
import ray
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
#from transformers import pipeline as pl
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from GPUtil import showUtilization as gpu_usage
import pandas as pd
import sys
import plotly.graph_objects as go
import torch
import gc
import jax
from numba import cuda
import math

print("GPU available", torch.cuda.is_available())
#print('__CUDA Device Name:',torch.cuda.get_device_name(0))
print(os.getcwd())

if "/home/user/app/alphafold" not in sys.path:
    sys.path.append("/home/user/app/alphafold")

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
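
# This Space runs AlphaFold without MSA or template search: run_alphafold() below feeds the
# query as a single-sequence MSA, and mk_mock_template() provides an all-gap, all-zero
# template so the model's template features are valid but empty.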
def mk_mock_template(query_sequence):
    """create blank template"""
    ln = len(query_sequence)
    output_templates_sequence = "-" * ln
    templates_all_atom_positions = np.zeros(
        (ln, templates.residue_constants.atom_type_num, 3)
    )
    templates_all_atom_masks = np.zeros((ln, templates.residue_constants.atom_type_num))
    templates_aatype = templates.residue_constants.sequence_to_onehot(
        output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID
    )
    template_features = {
        "template_all_atom_positions": templates_all_atom_positions[None],
        "template_all_atom_masks": templates_all_atom_masks[None],
        "template_aatype": np.array(templates_aatype)[None],
        "template_domain_names": ["none".encode()],
    }
    return template_features

def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
    """Predicts structure using AlphaFold for the given sequence."""
    # Run the models (currently only model_1).
    plddts = {}
    for model_name, model_runner in model_runners.items():
        processed_feature_dict = model_runner.process_features(
            feature_dict, random_seed=random_seed
        )
        prediction_result = model_runner.predict(processed_feature_dict)
        b_factors = (
            prediction_result["plddt"][:, None]
            * prediction_result["structure_module"]["final_atom_mask"]
        )
        unrelaxed_protein = protein.from_prediction(
            processed_feature_dict, prediction_result, b_factors
        )
        unrelaxed_pdb_path = f"/home/user/app/{prefix}_unrelaxed_{model_name}.pdb"
        plddts[model_name] = prediction_result["plddt"]
        print(f"{model_name} {plddts[model_name].mean()}")
        with open(unrelaxed_pdb_path, "w") as f:
            f.write(protein.to_pdb(unrelaxed_protein))
    return plddts
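
# Perplexity of a sequence under ProtGPT2: the exponential of the mean cross-entropy loss
# obtained when the tokenised sequence is used as its own label.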
def compute_perplexity(model, tokenizer, sequence):
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]
    return math.exp(loss)
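
# Generation pipeline: oversample candidates with top-k sampling, keep those with no newline
# in the first 60 characters and that were not truncated before the end-of-text token, rank
# them by length-normalised perplexity and return the best max_seqs sequences.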
@ray.remote
def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
    print("running protgpt2")
    gpu_usage()  # showUtilization prints the table itself
    seqs_to_sample = int(max_seqs) * 10  # oversample so the best-scoring sequences can be kept
    #protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
    model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2")
    tokenizer = GPT2Tokenizer.from_pretrained("nferruz/ProtGPT2")
    input_ids = tokenizer.encode(startsequence, return_tensors="pt")
    sequences = model.generate(
        input_ids,
        max_length=int(length),  # Gradio Number/Slider components return floats
        do_sample=True,
        top_k=int(top_k_poolsize),
        repetition_penalty=repetitionPenalty,
        num_return_sequences=seqs_to_sample,
        eos_token_id=0,
    )
    filtered_sequences = []
    for sequence in sequences:
        decoded_seq = tokenizer.decode(sequence)
        # No newlines in the first line and no truncation before the end-of-text token
        if "\n" not in decoded_seq[0:60] and decoded_seq.count("<|endoftext|>") >= 2:
            clean_seq = decoded_seq.split("<|endoftext|>")[0]
            ppl = compute_perplexity(model, tokenizer, clean_seq)
            filtered_sequences.append((clean_seq, ppl / len(clean_seq)))
    # TODO: show a warning if not enough sequences fulfil the criteria
    selected_sequences = sorted(filtered_sequences, key=lambda x: x[1])[:int(max_seqs)]
    # sequences = protgpt2(
    #     startsequence,
    #     max_length=length,
    #     do_sample=True,
    #     top_k=top_k_poolsize,
    #     repetition_penalty=repetitionPenalty,
    #     num_return_sequences=seqs_to_sample,
    #     eos_token_id=0,
    # )
    print("Cleaning up after protGPT2")
    #print(gpu_usage())
    #torch.cuda.empty_cache()
    #device = cuda.get_current_device()
    #device.reset()
    #print(gpu_usage())
    return selected_sequences
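
# Structure prediction: run a single AlphaFold model (model_1) with one ensemble pass, using
# only the query sequence (no MSA search) plus the blank template defined above.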
@ray.remote
def run_alphafold(startsequence):
    gpu_usage()
    model_runners = {}
    models = ["model_1"]  # ,"model_2","model_3","model_4","model_5"]
    for model_name in models:
        model_config = config.model_config(model_name)
        model_config.data.eval.num_ensemble = 1
        model_params = data.get_model_haiku_params(model_name=model_name, data_dir="/home/user/app/")
        model_runner = model.RunModel(model_config, model_params)
        model_runners[model_name] = model_runner
    query_sequence = startsequence.replace("\n", "")
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=query_sequence, description="none", num_res=len(query_sequence)
        ),
        **pipeline.make_msa_features(
            msas=[[query_sequence]], deletion_matrices=[[[0] * len(query_sequence)]]
        ),
        **mk_mock_template(query_sequence),
    }
    plddts = predict_structure("test", feature_dict, model_runners)
    print("AF2 done")
    #backend = jax.lib.xla_bridge.get_backend()
    #for buf in backend.live_buffers(): buf.delete()
    #device = cuda.get_current_device()
    #device.reset()
    #print(gpu_usage())
    return plddts["model_1"]
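
# Gradio callback: dispatch generation to the Ray task and format the returned
# (sequence, normalised perplexity) pairs as FASTA-style text wrapped at 70 residues per line.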
def update_protGPT2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs):
    startsequence = inp
    seqlen = length
    generated_seqs = ray.get(run_protgpt2.remote(startsequence, seqlen, repetitionPenalty, top_k_poolsize, max_seqs))
    print(generated_seqs)
    sequencestxt = ""
    for i, (seq, norm_ppl) in enumerate(generated_seqs):
        s = seq.replace("\n", "")
        seqlen = len(s)
        s = "\n".join([s[j:j + 70] for j in range(0, len(s), 70)])
        sequencestxt += f">seq{i}, {seqlen} residues\n{s}\n\n"
    return sequencestxt
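
# Gradio callback: run AlphaFold in the Ray task, then return the 3Dmol.js viewer HTML,
# a Plotly per-residue pLDDT plot and the mean pLDDT of the prediction.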
def update(inp):
    print("Running AF on", inp)
    startsequence = inp
    # run AlphaFold as a Ray task
    plddts = ray.get(run_alphafold.remote(startsequence))
    print(plddts)
    #plt.style.use(["seaborn-ticks", "seaborn-talk"])
    #fig = plt.figure()
    #ax = fig.add_subplot(111)
    #ax.plot(plddts)
    #ax.set_ylabel("predicted LDDT")
    #ax.set_xlabel("positions")
    #ax.set_title("pLDDT")
    fig = go.Figure(
        data=go.Scatter(
            x=np.arange(len(plddts)),
            y=plddts,
            hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}",
        )
    )
    fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
    )
    return (
        molecule("test_unrelaxed_model_1.pdb"),
        fig,
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )

def read_mol(molpath):
    with open(molpath, "r") as fp:
        return fp.read()
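
# molecule() wraps the predicted PDB in a self-contained HTML page that renders it with
# 3Dmol.js, colouring the cartoon by the pLDDT values stored in the B-factor column, and
# returns the page embedded in an iframe via srcdoc.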
def molecule(pdb):
    mol = read_mol(pdb)
    x = (
        """<!DOCTYPE html>
<html>
<head>
  <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
  <link rel="stylesheet" href="https://unpkg.com/[email protected]/dist/flowbite.min.css" />
  <style>
    body {
      font-family: sans-serif;
    }
    .mol-container {
      width: 100%;
      height: 800px;
      position: relative;
    }
    .space-x-2 > * + * {
      margin-left: 0.5rem;
    }
    .p-1 {
      padding: 0.5rem;
    }
    .flex {
      display: flex;
      align-items: center;
    }
    .w-4 {
      width: 1rem;
    }
    .h-4 {
      height: 1rem;
    }
    .mt-4 {
      margin-top: 1rem;
    }
  </style>
  <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
  <div id="container" class="mol-container"></div>
  <div class="flex">
    <div class="px-4">
      <label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer">
        <input id="sidechain" type="checkbox" class="sr-only peer">
        <div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
        <span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span>
      </label>
    </div>
    <button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download">
      <svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg>
      Download predicted structure
    </button>
  </div>
  <div class="text-sm">
    <div class="font-medium mt-4"><b>AlphaFold model confidence:</b></div>
    <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(0, 83, 214);"> </span><span class="legendlabel">Very high (pLDDT > 90)</span></div>
    <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(101, 203, 243);"> </span><span class="legendlabel">Confident (90 > pLDDT > 70)</span></div>
    <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(255, 219, 19);"> </span><span class="legendlabel">Low (70 > pLDDT > 50)</span></div>
    <div class="flex space-x-2 py-1"><span class="w-4 h-4" style="background-color: rgb(255, 125, 69);"> </span><span class="legendlabel">Very low (pLDDT < 50)</span></div>
    <div class="row column legendDesc">AlphaFold produces a per-residue confidence score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation.</div>
  </div>
  <script>
    let viewer = null;
    let voldata = null;
    $(document).ready(function () {
      let element = $("#container");
      let config = { backgroundColor: "white" };
      viewer = $3Dmol.createViewer( element, config );
      viewer.ui.initiateUI();
      let data = `"""
        + mol
        + """`
      viewer.addModel( data, "pdb" );
      // AlphaFold colouring code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96
      let colorAlpha = function (atom) {
        if (atom.b < 50) {
          return "OrangeRed";
        } else if (atom.b < 70) {
          return "Gold";
        } else if (atom.b < 90) {
          return "MediumTurquoise";
        } else {
          return "Blue";
        }
      };
      viewer.setStyle({}, { cartoon: { colorfunc: colorAlpha } });
      viewer.zoomTo();
      viewer.render();
      viewer.zoom(0.8, 2000);
      viewer.getModel(0).setHoverable({}, true,
        function (atom, viewer, event, container) {
          console.log(atom)
          if (!atom.label) {
            atom.label = viewer.addLabel(atom.resn + atom.resi + " pLDDT=" + atom.b, { position: atom, backgroundColor: "mintcream", fontColor: "black" });
          }
        },
        function (atom, viewer) {
          if (atom.label) {
            viewer.removeLabel(atom.label);
            delete atom.label;
          }
        }
      );
      $("#sidechain").change(function () {
        if (this.checked) {
          let BB = ["C", "O", "N"]
          viewer.setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: { colorfunc: colorAlpha }});
          viewer.render()
        } else {
          viewer.setStyle({cartoon: { colorfunc: colorAlpha }});
          viewer.render()
        }
      });
      $("#download").click(function () {
        download("gradioFold_model1.pdb", data);
      })
    });
    function download(filename, text) {
      var element = document.createElement("a");
      element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text));
      element.setAttribute("download", filename);
      element.style.display = "none";
      document.body.appendChild(element);
      element.click();
      document.body.removeChild(element);
    }
  </script>
</body></html>"""
    )
    return f"""<iframe style="width: 800px; height: 1200px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""

def change_sequence(chosenSeq):
    return chosenSeq
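
# Gradio Blocks UI: a ProtGPT2 panel (start sequence plus sampling parameters) and an
# AlphaFold panel showing the 3D structure, a per-residue pLDDT plot and the mean pLDDT.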
proteindream = gr.Blocks()

with proteindream:
    gr.Markdown("# GradioFold")
    gr.Markdown(
        """GradioFold is a web-based tool that combines a large language model trained on natural protein sequences (ProtGPT2) with structure prediction using AlphaFold.
    Type a start sequence for ProtGPT2 to complete, or leave the field empty to let ProtGPT2 generate a complete sequence without a start token."""
    )
    gr.Markdown("## ProtGPT2")
    gr.Markdown(
        """
    Enter a start sequence for the language model to complete, or leave the field empty.
    """
    )
    with gr.Box():
        with gr.Row():
            inp = gr.Textbox(placeholder="MTYKLILNGKTLKGETTT", label="Start sequence")
            length = gr.Number(value=100, label="Max sequence length")
        with gr.Row():
            repetitionPenalty = gr.Slider(minimum=1, maximum=5, value=1.2, label="Repetition penalty")
            top_k_poolsize = gr.Slider(minimum=700, maximum=52056, value=950, label="Top-k sampling pool size")
            max_seqs = gr.Slider(minimum=2, maximum=20, value=5, step=1, label="Number of sequences to generate")
        btn = gr.Button("Predict sequences using ProtGPT2")
        results = gr.Textbox(label="Results", lines=15)
    btn.click(fn=update_protGPT2, inputs=[inp, length, repetitionPenalty, top_k_poolsize, max_seqs], outputs=results)
    gr.Markdown("## AlphaFold")
    gr.Markdown(
        "Select one of the generated sequences above and copy it into the field below for structure prediction with AlphaFold2. You can also edit the sequence. Predictions take around 2-5 minutes. Proteins larger than about 1000 residues will not fit into memory."
    )
    with gr.Group():
        chosenSeq = gr.Textbox(label="Chosen sequence")
        btn2 = gr.Button("Predict structure")
    with gr.Group():
        meanpLDDT = gr.Textbox(label="Mean pLDDT of chosen sequence")
        with gr.Row():
            mol = gr.HTML()
            plot = gr.Plot(label="pLDDT")
    gr.Markdown(
        """## Acknowledgements
    More information about the algorithms used can be found below.
    All code is available on [Hugging Face](https://huggingface.co/spaces/simonduerr/protGPT2_gradioFold/blob/main) and licensed under the MIT license.
    - ProtGPT2: Ferruz et al. 📃[BioRxiv](https://doi.org/10.1101/2022.03.09.483666) 💻[Code](https://huggingface.co/nferruz/ProtGPT2)
    - AlphaFold2: Jumper et al. 📃[Paper](https://doi.org/10.1038/s41586-021-03819-2) 💻[Code](https://github.com/deepmind/alphafold) Model parameters released under CC BY 4.0
    - ColabFold: Mirdita et al. 📃[Paper](https://doi.org/10.1101/2021.08.15.456425) 💻[Code](https://github.com/sokrypton/ColabFold)
    - 3Dmol.js: Rego & Koes 📃[Paper](https://academic.oup.com/bioinformatics/article/31/8/1322/213186) 💻[Code](https://github.com/3dmol/3Dmol.js)

    Created by [@simonduerr](https://twitter.com/simonduerr)

    Thanks to the Hugging Face team for sponsoring a free GPU for this demo.
    """
    )
    #seqChoice.change(fn=update_seqs, inputs=seqChoice, outputs=chosenSeq)
    btn2.click(fn=update, inputs=chosenSeq, outputs=[mol, plot, meanpLDDT])
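
# Start Ray (shipping the local alphafold checkout to the worker runtime environment), then launch the app.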
ray.init(runtime_env={"working_dir": "./alphafold"})
proteindream.launch(share=False)