Vipitis committed on
Commit
791c9fd
1 Parent(s): 829134c

cleanup body generation function

Browse files
Files changed (1) hide show
  1. app.py +16 -38
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  import datasets
4
- import asyncio
5
  import numpy as np
6
  import torch
7
  from threading import Thread
8
 
9
- from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx
10
  from utils.html_utils import make_iframe, construct_embed
11
  PIPE = None
12
 
@@ -86,7 +85,7 @@ def grab_sample(sample_idx):
86
  # funcs = _parse_functions(sample_code)
87
  # func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
88
  # print(f"updating drop down to:{func_identifiers}")
89
- return sample_pass, sample_code, sample_title, source_iframe, funcs#, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_auhtor
90
 
91
  def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
92
  # if torch.cuda.is_available():
@@ -205,15 +204,14 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
205
  old_code (str): The original code.
206
  func_node (Node): The node of the function to replace the body of.
207
  funcs_list (list): The list of all functions in the code.
208
- prompt (str): The prompt(title) to use for generation.
209
- temperature (float): The temperature to use for generation.
210
- max_new_tokens (int): The maximum number of tokens to generate.
211
- top_p (float): The top_p to use for generation.
212
- repetition_penalty (float): The repetition_penalty to use for generation.
213
  pipeline (Pipeline): The pipeline to use for generation.
214
  Returns:
215
  str: The altered code.
216
- pipeline (Pipeline): The pipeline to update the state
217
  """
218
  if isinstance(func_id, str):
219
  print(f"{func_id=}")
@@ -226,22 +224,13 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
226
  print(f"using for generation: {func_node=}")
227
 
228
  generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
229
-
230
- print(f"{pipeline=}") # check if default even loaded
231
- if pipeline is None:
232
- print("no pipeline found, loading default one")
233
- pipeline = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine")
234
 
235
  func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
236
  identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
237
  body_node = func_node.child_by_field_name("body")
238
- body_start_idx, body_end_idx = node_str_idx(body_node) #can cause index error, needs better testing!
239
- # body_start_idx = line_chr2char(old_code, body_node.start_point[0], body_node.start_point[1])
240
- # body_end_idx = line_chr2char(old_code, body_node.end_point[0], body_node.end_point[1])
241
- print(f"{old_code[body_start_idx:body_end_idx]=}")
242
  model_context = identifier_str # base case
243
- # add any comments at the beginning of the function to the model_context
244
- # second_child = func_node.child_by_field_name("body").children[1] #might error out?
245
  docstring = get_docstrings(func_node) #might be empty?
246
  if docstring:
247
  model_context = model_context + "\n" + docstring
@@ -255,31 +244,23 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
255
  generation = _run_generation(model_context, pipeline, generation_kwargs)
256
  for i in generation:
257
  print(f"{i=}")
258
- yield model_context + i, pipeline #fix in between, do all the stuff in the end?
259
  generation = i[:] #seems to work
260
  print(f"{generation=}")
261
  ctx_with_generation = model_context + generation
262
- print(f"{ctx_with_generation=}")
263
  try:
264
  #strip the body
265
  first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
266
  except IndexError:
267
  print("generation wasn't a full function.")
268
  altered_code = old_code[:func_start_idx] + model_context + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
269
- return altered_code, pipeline
270
- # raise gr.Error(f"didn't generate a full function: {generation!r}]")
271
- # print(f"{first_gened_func=}")
272
- generated_body = first_gened_func.child_by_field_name("body").text.decode()
273
- print(f"{generated_body=}")
274
- altered_code = old_code[:func_start_idx] + identifier_str + generated_body + old_code[body_end_idx:]
275
- print(f"{altered_code=}") #we get here successfully
276
- yield altered_code, pipeline #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
277
- return altered_code, pipeline #never gets used by the code block? maybe I need to yield it first? but works in the ov_notebook
278
 
279
  def list_dropdown(in_code): #only used for auto update, not on sample pick?
280
  funcs = parse_functions(in_code)
281
-
282
- # print(f"updating drop down to:{func_identifiers=}")
283
  func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
284
  # funcs = [n for n in funcs] #wrapped as set to avoid json issues?
285
  print(f"updating drop down to:{func_identifiers}")
@@ -349,17 +330,14 @@ if __name__ == "__main__": #works on huggingface?
349
  sample_code = gr.Code(new_shadertoy_code, label="Current Code (will update changes you generate)", language=None)
350
  bot_md = gr.Markdown(outro_text)
351
  sample_pass = gr.State(value={})
 
352
  pipe = gr.State(value=PIPE)
353
  pipe.value=_make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine") # set a default like this?
354
- funcs = gr.State(value=[])
355
- # funcs.value.append(list_dropdown(sample_code.value)[0]) #to circumvent the json issue?
356
- # hist_state = gr.State(Value={})
357
- # history_table = gr.JSON()
358
 
359
  model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
360
  sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed]) #funcs here?
361
  gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code])
362
- gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe]).then(
363
  fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
364
  )
365
  sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]).then(
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  import datasets
 
4
  import numpy as np
5
  import torch
6
  from threading import Thread
7
 
8
+ from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
9
  from utils.html_utils import make_iframe, construct_embed
10
  PIPE = None
11
 
 
85
  # funcs = _parse_functions(sample_code)
86
  # func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
87
  # print(f"updating drop down to:{func_identifiers}")
88
+ return sample_pass, sample_code, sample_title, source_iframe#, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_auhtor
89
 
90
  def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
91
  # if torch.cuda.is_available():
 
204
  old_code (str): The original code.
205
  func_node (Node): The node of the function to replace the body of.
206
  funcs_list (list): The list of all functions in the code.
207
+ prompt (str): The prompt(title) to use for generation. defaults to "".
208
+ temperature (float): The temperature to use for generation. defaults to 0.2.
209
+ max_new_tokens (int): The maximum number of tokens to generate. defaults to 512.
210
+ top_p (float): The top_p to use for generation. defaults to 0.95.
211
+ repetition_penalty (float): The repetition_penalty to use for generation. defaults to 1.2.
212
  pipeline (Pipeline): The pipeline to use for generation.
213
  Returns:
214
  str: The altered code.
 
215
  """
216
  if isinstance(func_id, str):
217
  print(f"{func_id=}")
 
224
  print(f"using for generation: {func_node=}")
225
 
226
  generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
 
 
 
 
 
227
 
228
  func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
229
  identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
230
  body_node = func_node.child_by_field_name("body")
231
+ body_start_idx, body_end_idx = node_str_idx(body_node)
 
 
 
232
  model_context = identifier_str # base case
233
+
 
234
  docstring = get_docstrings(func_node) #might be empty?
235
  if docstring:
236
  model_context = model_context + "\n" + docstring
 
244
  generation = _run_generation(model_context, pipeline, generation_kwargs)
245
  for i in generation:
246
  print(f"{i=}")
247
+ yield model_context + i #fix in between, do all the stuff in the end?
248
  generation = i[:] #seems to work
249
  print(f"{generation=}")
250
  ctx_with_generation = model_context + generation
 
251
  try:
252
  #strip the body
253
  first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
254
  except IndexError:
255
  print("generation wasn't a full function.")
256
  altered_code = old_code[:func_start_idx] + model_context + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
257
+ return altered_code
258
+ altered_code = replace_function(func_node, first_gened_func)
259
+ yield altered_code #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
260
+ return altered_code #never gets used by the code block? maybe I need to yield it first? but works in the ov_notebook
 
 
 
 
 
261
 
262
  def list_dropdown(in_code): #only used for auto update, not on sample pick?
263
  funcs = parse_functions(in_code)
 
 
264
  func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
265
  # funcs = [n for n in funcs] #wrapped as set to avoid json issues?
266
  print(f"updating drop down to:{func_identifiers}")
 
330
  sample_code = gr.Code(new_shadertoy_code, label="Current Code (will update changes you generate)", language=None)
331
  bot_md = gr.Markdown(outro_text)
332
  sample_pass = gr.State(value={})
333
+ funcs = gr.State(value=[])
334
  pipe = gr.State(value=PIPE)
335
  pipe.value=_make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine") # set a default like this?
 
 
 
 
336
 
337
  model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
338
  sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed]) #funcs here?
339
  gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code])
340
+ gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code]).then(
341
  fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
342
  )
343
  sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]).then(