ShaderCoder / app.py
Vipitis's picture
fix iframe, add generation
46dd33b
raw
history blame
12.5 kB
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from wgpu.utils.shadertoy import *
from wgpu.gui.offscreen import WgpuCanvas as OffscreenCanvas, run as run_offscreen
import wgpu
import time
import ctypes
import datasets
from PIL import Image
import asyncio
import numpy as np
# reimplement the Shadertoy class with offscreen canvas!
class ShadertoyCustom(Shadertoy):
    """Shadertoy renderer whose canvas class and run function are injectable.

    Passing an offscreen canvas class (and its matching ``run`` function)
    lets a shader be rendered without opening a window.

    Args:
        shader_code: Source of the shader to render.
        resolution: ``(width, height)`` of the render target.
        canvas_class: Class used to create the canvas in ``_prepare_render``.
        run_fn: Callable invoked by ``show`` to start the event loop.
    """

    def __init__(self, shader_code, resolution=(800, 450), canvas_class=WgpuCanvas, run_fn=run):
        # These must be assigned BEFORE the base constructor runs, because the
        # overridden _prepare_render() reads self._canvas_class.
        self._canvas_class = canvas_class
        self._fun_fn = run_fn  # (sic) attribute name kept for compatibility
        # NOTE(review): relies on Shadertoy.__init__ building the uniform
        # block, storing the shader code and calling _prepare_render() and
        # _bind_events() itself (as wgpu-py does). The previous version
        # repeated all of that here, which created a second canvas, adapter
        # and device and orphaned the first set. Confirm against the pinned
        # wgpu version.
        super().__init__(shader_code, resolution)

    def _prepare_render(self):
        """Create the canvas, device, shader modules and render pipeline."""
        # Imported for its side effect; presumably selects the native backend.
        import wgpu.backends.rs  # noqa

        self._canvas = self._canvas_class(title="Shadertoy", size=self.resolution, max_fps=60)

        adapter = wgpu.request_adapter(
            canvas=self._canvas, power_preference="high-performance"
        )
        self._device = adapter.request_device()

        self._present_context = self._canvas.get_context()

        # We use "bgra8unorm" not "bgra8unorm-srgb" here because we want to let the shader fully control the color-space.
        self._present_context.configure(
            device=self._device, format=wgpu.TextureFormat.bgra8unorm
        )

        # The user code is wrapped between the built-in variable declarations
        # and the fragment entry point for the detected shading language.
        shader_type = self.shader_type
        if shader_type == "glsl":
            vertex_shader_code = vertex_code_glsl
            frag_shader_code = (
                builtin_variables_glsl + self.shader_code + fragment_code_glsl
            )
        elif shader_type == "wgsl":
            vertex_shader_code = vertex_code_wgsl
            frag_shader_code = (
                builtin_variables_wgsl + self.shader_code + fragment_code_wgsl
            )
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(f"unsupported shader type: {shader_type!r}")

        vertex_shader_program = self._device.create_shader_module(
            label="triangle_vert", code=vertex_shader_code
        )
        frag_shader_program = self._device.create_shader_module(
            label="triangle_frag", code=frag_shader_code
        )

        # Single uniform buffer holding everything in self._uniform_data.
        self._uniform_buffer = self._device.create_buffer(
            size=self._uniform_data.nbytes,
            usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST,
        )

        bind_group_layout = self._device.create_bind_group_layout(
            entries=binding_layout
        )

        self._bind_group = self._device.create_bind_group(
            layout=bind_group_layout,
            entries=[
                {
                    "binding": 0,
                    "resource": {
                        "buffer": self._uniform_buffer,
                        "offset": 0,
                        "size": self._uniform_data.nbytes,
                    },
                },
            ],
        )

        self._render_pipeline = self._device.create_render_pipeline(
            layout=self._device.create_pipeline_layout(
                bind_group_layouts=[bind_group_layout]
            ),
            vertex={
                "module": vertex_shader_program,
                "entry_point": "main",
                "buffers": [],
            },
            primitive={
                "topology": wgpu.PrimitiveTopology.triangle_list,
                "front_face": wgpu.FrontFace.ccw,
                "cull_mode": wgpu.CullMode.none,
            },
            depth_stencil=None,
            multisample=None,
            fragment={
                "module": frag_shader_program,
                "entry_point": "main",
                "targets": [
                    {
                        # Must match the format configured on the present context.
                        "format": wgpu.TextureFormat.bgra8unorm,
                        # one * src + zero * dst: the output overwrites the target.
                        "blend": {
                            "color": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                            "alpha": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                        },
                    },
                ],
            },
        )

    def show(self, time: float = 0.0):
        """Request a draw of the current frame and start the run function.

        ``time`` is accepted for interface compatibility but is not used here.
        """
        self._canvas.request_draw(self._draw_frame)
        self._fun_fn()
# Markdown introduction and TODO list rendered at the top of the page via gr.Markdown(text).
text = """
# Welcome to the interactive shadercoding demo.
## (WIP), you can try and explore the dataset a bit right now. (frames are rendered on the fly, not part of the dataset(yet))
This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset, only shaders that const of a single pass (and have at least one fuction with a return statement) are available.
In the near future there will be some buttons and sliders to generate variations of the shadercode itself, and hence get some different images.
If I find an efficient way, the shaders might run in real time and be interactive.
## TODO:
- [x] use embedded Shadertoy for reference/attribution (done, but some errors)
- [] working render implementation on CPU only space (use the browser for WebGPU?)
- [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (missing all of the generation parameters)
- [] generation history stating which function and orig/generated returns. (use State ??).
- [] generate whole functions
- [] generate whole shaders (via prompts?)
"""
def _is_simple_single_pass(x):
    """Keep shaders with no inputs, exactly one pass and at least one ``return``."""
    if x["has_inputs"]:
        return False
    if x["num_passes"] != 1:
        return False
    return "return" in x["code"]


# Load the dataset, filter every split, then merge the train and test splits
# into one flat dataset that the sample slider indexes into.
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
single_passes = passes_dataset.filter(_is_simple_single_pass)  # filter easier than having a custom loader script?
all_single_passes = datasets.concatenate_datasets(
    [single_passes[split] for split in ("train", "test")]
)
num_samples = len(all_single_passes)
async def get_image(code, time=0.0, resolution=(512, 420)):
    """Render a single frame of ``code`` offscreen and return it as an RGB image.

    Args:
        code: Shader source handed to ``ShadertoyCustom``.
        time: Value written to the shader's ``time`` uniform before drawing.
        resolution: ``(width, height)`` of the rendered frame.

    Returns:
        A ``PIL.Image.Image`` in ``RGB`` mode.
    """
    # The offscreen canvas and its run function replace the windowed defaults.
    renderer = ShadertoyCustom(code, resolution, OffscreenCanvas, run_offscreen)
    renderer._uniform_data["time"] = time  # any point in time can be rendered

    canvas = renderer._canvas
    canvas.request_draw(renderer._draw_frame)
    pixels = np.asarray(canvas.draw())

    # Converting to RGB drops the alpha channel (no transparent pixels).
    return Image.fromarray(pixels).convert('RGB')
def grab_sample(sample_idx):
    """Fetch one shader from ``all_single_passes`` for display.

    Args:
        sample_idx: Position in the dataset as delivered by the slider. It is
            cast to ``int`` and capped at the last valid index, so a slider
            whose maximum equals the dataset length cannot raise IndexError.

    Returns:
        Tuple ``(sample_pass, sample_code, source_iframe)``: the full dataset
        record, its ``"code"`` field, and the embed HTML built from its
        ``"source"`` field.
    """
    last_idx = len(all_single_passes) - 1
    sample_idx = min(int(sample_idx), last_idx)

    sample_pass = all_single_passes[sample_idx]
    sample_code = sample_pass["code"]
    # The record also carries "title" and "author"; they are not returned yet.
    source_iframe = construct_embed(sample_pass["source"])
    print(f"{source_iframe=}")
    return sample_pass, sample_code, source_iframe
# Most recently loaded text-generation pipeline; None until _make_pipeline() has run.
# (Upper-case by historical naming; it is module state, not a constant.)
PIPE = None


def _make_pipeline(model_cp = "gpt2"): #bad default model for testing
    """Load a checkpoint as a text-generation pipeline and remember it in ``PIPE``.

    Args:
        model_cp: Model checkpoint name/path passed to ``from_pretrained``.

    Returns:
        The ``transformers`` text-generation pipeline that was created.
    """
    # Without this declaration the assignment below only bound a local name
    # and the module-level PIPE stayed None forever.
    global PIPE

    # SECURITY NOTE(review): trust_remote_code=True executes Python code shipped
    # with the checkpoint repository, and model_cp is typed by the user into a
    # Textbox. Consider restricting model_cp to an allow-list of checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
    PIPE = pipe
    print(f"loaded model {model_cp} as a pipline")
    return pipe
def process_retn(retn):
    """Return the text before the first ``;`` in ``retn``, stripped of whitespace."""
    first_statement, _, _ = retn.partition(";")
    return first_statement.strip()
def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """
    Patches the generated return statement into the code and returns the full altered code.

    The slice ``orig_code[retn_start_idx:retn_end_idx]`` is replaced by the
    cleaned-up ``prediction`` (see ``process_retn``); everything outside that
    slice is kept unchanged. The replaced slice and its replacement are printed
    for debugging.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    before = orig_code[:retn_start_idx]
    after = orig_code[retn_end_idx:]
    return "".join((before, generated, after))
def _find_return_spans(code):
    """Locate every replaceable ``return <expr>;`` statement in ``code``.

    Returns a list of ``(keyword_idx, expr_idx, semicolon_idx)`` tuples:
    where the ``return`` keyword starts, where its expression starts, and
    where the terminating ``;`` sits. Not included:
    - ``return`` as part of a longer identifier (e.g. ``returned``),
    - void returns (``return;``), which have no expression to replace,
    - a ``return`` with no ``;`` after it.
    """
    import re  # function-scope import: the module itself does not import re

    spans = []
    scan_from = 0
    for match in re.finditer(r"\breturn\b", code):
        if match.start() < scan_from:
            # Lies inside a statement that was already consumed.
            continue
        semicolon_idx = code.find(";", match.end())
        if semicolon_idx == -1:
            break  # unterminated statement: nothing sensible to replace
        scan_from = semicolon_idx
        expr_idx = match.end()
        while expr_idx < semicolon_idx and code[expr_idx].isspace():
            expr_idx += 1
        if expr_idx < semicolon_idx:
            spans.append((match.start(), expr_idx, semicolon_idx))
    return spans


def alter_return(orig_code, func_idx=0, pipeline=None):
    """
    Replaces the return statement of a function with a generated one.

    Args:
        orig_code (str): The original code.
        func_idx (int): The index of the return statement to replace; clamped
            to the valid range and truncated to int (the UI passes a float).
        pipeline (Pipeline): The pipeline to use for generation. When None,
            an already loaded module-level ``PIPE`` is reused, otherwise the
            default pipeline is loaded.
    Returns:
        str: The altered code, or ``orig_code`` unchanged when it contains no
        replaceable return statement.
    """
    retrns = _find_return_spans(orig_code)
    if not retrns:
        # Checked before touching any model so nothing is loaded needlessly.
        print("no return statement found, returning original code")
        return orig_code

    if pipeline is None:
        if PIPE is not None:
            pipeline = PIPE
        else:
            print("no pipeline found, loading default one")
            pipeline = _make_pipeline()

    func_idx = int(max(0, min(func_idx, len(retrns) - 1))) #clamp to valid range, cast to int as a bodge.
    retrn_start_idx, expr_start_idx, retrn_end_idx = retrns[func_idx]
    keyword_end_idx = retrn_start_idx + len("return")

    # Prompt: everything up to and including the "return" keyword. #TODO: maximal context?
    model_inp = orig_code[:keyword_end_idx]
    new_toks = (retrn_end_idx - retrn_start_idx) * 2 #TODO: approximation, we do have early stopping? maybe also use a number instead?
    pipe_generation = pipeline(model_inp, max_new_tokens=new_toks, return_full_text=False)[0]["generated_text"] #pipeline kwargs are missing?!

    # Normalise to exactly one space after the keyword so the splice is also
    # right for "return(x);" or "return\n  x;". For the common "return x;"
    # this leaves the code and both indices untouched.
    spaced_code = orig_code[:keyword_end_idx] + " " + orig_code[expr_start_idx:]
    shift = keyword_end_idx + 1 - expr_start_idx
    altered_code = get_full_replacement(
        spaced_code, keyword_end_idx + 1, retrn_end_idx + shift, pipe_generation
    )
    return altered_code
def add_history(func_id, orig_rtn, gened_rtn, history):
    """Record an (original, generated) return pair under ``func_id``.

    ``history`` is modified in place (anything supporting item assignment)
    and the same object is returned twice, once per UI output.
    """
    # is this a list? or a JSON dict?
    pair = (orig_rtn, gened_rtn)
    history[func_id] = pair
    return history, history
def construct_embed(source_url):
    """Build the Shadertoy ``<iframe>`` embed HTML for a shader URL.

    Args:
        source_url: URL (or bare id) whose last path segment is the shader id.
            Surrounding whitespace and trailing slashes are ignored, so a
            ".../view/<id>/" URL no longer produces an empty id.

    Returns:
        HTML string embedding the shader paused and muted at ``t=0``.
    """
    shader_id = source_url.strip().rstrip("/").rsplit("/", 1)[-1]
    return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
# Page layout and event wiring. Components are created top to bottom in the
# order they appear on the page; the event handlers are attached afterwards.
with gr.Blocks() as site:
    text_md = gr.Markdown(text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
    # The slider range is inclusive, so the last valid dataset index is
    # num_samples - 1 (num_samples itself would raise IndexError in grab_sample).
    sample_idx = gr.Slider(minimum=0, maximum=num_samples - 1, value=3211, label="pick sample from dataset", step=1.0)
    run_button = gr.Button("generate a alternate return statement for one function", label="generate code")
    render_button = gr.Button("render frame0 (can carsh the sapce on invalid shadercode)",label="render frame")
    time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release, also used to pick other functions as a bodge)", step=0.02)
    #output = gr.Textbox(label="Output")
    rendered_frame = gr.Image(shape=(512, 420), label=f"rendered frame preview")
    # info_md = gr.Markdown(value="code_source", label="source URL for this shader", interactive=False)
    source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/WsBcWV?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>', label="How this shader originally renders")
    sample_code = gr.Code(label="Current Code (will update changes you generate)", language=None, readonly=True, lines=20)

    # Per-session state: the selected dataset record and the loaded pipeline.
    sample_pass = gr.State(value={})
    pipe = gr.State(value=PIPE)
    # hist_state = gr.State(Value={})
    # history_table = gr.JSON()

    # Enter in the checkpoint box loads that model into the session state.
    model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe])
    # Releasing the sample slider loads the record, its code and its embed.
    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
    # The time slider value doubles as the index of the return statement to replace.
    run_button.click(fn=alter_return, inputs=[sample_code, time_slider, pipe], outputs=[sample_code])
    # run_button.click(fn=add_history, inputs=[time_slider, sample_pass, sample_code, hist_state], outputs=[history_table, hist_state])
    # sample_idx.release(fn=construct_embed, inputs=[sample_idx], outputs=[source_embed]) #twice to make have different outputs?
    # get_image is a coroutine function, so it is driven to completion with asyncio.run.
    time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
    render_button.click(fn=lambda code: asyncio.run(get_image(code)), inputs=[sample_code], outputs=rendered_frame)
    # run_button.click(fn=print, inputs=[model_cp, sample_idx], outputs=output)

site.launch()