streaming text generation in working shape
app.py CHANGED
@@ -1,9 +1,10 @@
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import datasets
 import asyncio
 import numpy as np
 import torch
+from threading import Thread
 
 def make_script(shader_code):
     # code copied and fixed (escaping single quotes to double quotes!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
@@ -274,6 +275,7 @@ outro_text ="""
 - [~] include some context for prompt (title, comments before a function) - now works with the first comment inside a function body (has to be first)
 - [ ] gradio examples
 - [ ] use GPU if available, respect memory restrictions.
+- [~] stream model generation (maybe in a new window?) - WIP for body gen right now -> janky solution works.
 
 ### Notes:
 - this is meant as a resource to show code generation for a "creative" task.
@@ -342,6 +344,34 @@ def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #
     print(f"loaded model {model_cp} as a pipeline")
     return pipe
 
+def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
+    """
+    Text generation function.
+    Args:
+        model_ctx (str): The context to start generation from.
+        pipe (Pipeline): The pipeline to use for generation.
+        gen_kwargs (dict): The generation kwargs.
+    Returns:
+        str: The generated text. (it iterates over time)
+    """
+    # Tokenize the model_context
+    model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt")
+
+    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
+    # in the main thread. Adds a timeout to the streamer to handle exceptions in the generation thread.
+    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
+    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
+    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Pull the generated text from the streamer, and update the model output.
+    model_output = ""
+    for new_text in streamer:
+        # print("step", end="")
+        model_output += new_text
+        yield model_output
+    streamer.on_finalized_text("stream reached the end.")
+    return model_output  # is this ever reached?
 
 def process_retn(retn):
     return retn.split(";")[0].strip()
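The added `_run_generation` is the stock `TextIteratorStreamer` recipe from transformers: `generate()` runs on a worker thread while the caller drains the streamer on the main thread, yielding a progressively longer string. A minimal self-contained sketch of the same pattern (the gpt2 checkpoint and the prompt are stand-ins here, not what this Space uses):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("void mainImage(", return_tensors="pt")
    # skip_prompt=True yields only newly generated tokens; the timeout stops the
    # consumer loop from hanging forever if the generation thread dies.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)

    thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32))
    thread.start()

    text = ""
    for chunk in streamer:  # blocks until the next decoded chunk arrives
        text += chunk
        print(text)
    thread.join()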
@@ -458,7 +488,12 @@ def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens,
     # print(second_child.text.decode())
     model_context += " { \n " + second_child.text.decode()
     print(f"{model_context=}")
-    generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
+    # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
+    generation = _run_generation(model_context, pipeline, generation_kwargs)
+    for i in generation:
+        print(f"{i=}")
+        yield model_context + i, pipeline  # fix in between, do all the stuff in the end?
+        generation = i[:]  # seems to work
     print(f"{generation=}")
     ctx_with_generation = model_context + generation
     print(f"{ctx_with_generation=}")
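Turning the body-generation path into a generator is what makes the streaming visible in the UI: Gradio re-renders the outputs on every `yield` from an event handler, which is why each partial generation is yielded together with the `pipeline` state. A toy sketch of the mechanism, independent of this Space (component and function names are illustrative only):

    import time
    import gradio as gr

    def count_up(n):
        total = ""
        for i in range(int(n)):
            total += f"{i} "
            time.sleep(0.2)
            yield total  # each yield overwrites the textbox contents

    with gr.Blocks() as demo:
        steps = gr.Number(value=10, label="steps")
        out = gr.Textbox(label="stream")
        gr.Button("run").click(fn=count_up, inputs=[steps], outputs=[out])

    demo.queue()  # generator handlers only stream when the queue is enabled
    demo.launch()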
@@ -474,7 +509,9 @@ def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens,
     generated_body = first_gened_func.child_by_field_name("body").text.decode()
     print(f"{generated_body=}")
     altered_code = old_code[:func_start_idx] + identifier_str + generated_body + old_code[body_end_idx:]
-
+    print(f"{altered_code=}")  # we get here successfully
+    yield altered_code, pipeline  # yield once so it updates? -> works... gg
+    return altered_code, pipeline  # never gets used by the code block? maybe I need to yield it first? but works in the ov_notebook
 
 def add_history(func_id, orig_rtn, gened_rtn, history):
     # is this a list? or a JSON dict?
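About the `return` that "never gets used": in Python, a `return` inside a generator merely ends the iteration; the value rides along on `StopIteration` and is dropped by a plain `for` loop, so only yielded values ever reach Gradio. The final `yield altered_code, pipeline` is therefore what actually updates the UI:

    def gen():
        yield "update 1"
        return "never shown"  # becomes StopIteration.value; a for loop ignores it

    print(list(gen()))  # prints ['update 1']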
@@ -524,7 +561,7 @@ with gr.Blocks() as site:
     with column_2:
         top_p = gr.Slider(
             label="Top-p (nucleus sampling)",
-            value=0.
+            value=0.85,
             minimum=0.0,
             maximum=1,
             step=0.05,
@@ -563,4 +600,4 @@ with gr.Blocks() as site:
     gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe])
     sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown])  # to update this after generation, so spans aren't messed up
     sample_code.change(fn=make_iframe, inputs=[sample_code], outputs=[our_embed])  # twice could cause issues, find better ways.
-site.launch()
+site.launch(enable_queue=True)
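The switch to `site.launch(enable_queue=True)` is what lets the yielded updates reach the browser: gradio 3.x-era builds refuse to run generator handlers unless the queue is enabled. Newer gradio releases drop the `enable_queue` launch flag in favor of calling `site.queue()` before `site.launch()`.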
|