multimodalart HF staff shyamsn97 commited on
Commit
1ec2ec6
1 Parent(s): 141b1fb

Adding radio inputs + adding additional parameters (#2)

Browse files

- Adding radio inputs + adding additional parameters (8d86ca5eb3f213423a2d767f32d91ba8d2225b57)


Co-authored-by: Shyam Sudhakaran <[email protected]>

Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -11,26 +11,46 @@ device = torch.device('cuda')
11
  mario_lm = mario_lm.to(device)
12
  TILE_DIR = "data/tiles"
13
 
14
- def update(prompt):
 
 
 
 
 
15
  prompts = [prompt]
16
  generated_level = mario_lm.sample(
17
  prompts=prompts,
18
- num_steps=1399,
19
- temperature=2.0,
20
  use_tqdm=True
21
  )
22
  img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
23
- return img
24
 
25
  with gr.Blocks() as demo:
26
- gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt)")
27
- prompt = gr.Textbox(label="Enter your MarioGPT prompt")
 
 
28
  level_image = gr.Image()
29
  btn = gr.Button("Generate level")
30
- btn.click(fn=update, inputs=prompt, outputs=level_image)
 
 
 
 
 
 
 
 
31
  gr.Examples(
32
- examples=["many pipes, many enemies, some blocks, high elevation", "little pipes, little enemies, many blocks, high elevation", "many pipes, some enemies", "no pipes, no enemies, many blocks"],
33
- inputs=prompt,
 
 
 
 
 
34
  outputs=level_image,
35
  fn=update,
36
  cache_examples=True,
 
11
  mario_lm = mario_lm.to(device)
12
  TILE_DIR = "data/tiles"
13
 
14
+
15
+
16
+ def update(pipes, enemies, blocks, elevation, temperature = 2.0, level_size = 1399, prompt = ""):
17
+ if prompt == "":
18
+ prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
19
+ print(f"Using prompt: {prompt}")
20
  prompts = [prompt]
21
  generated_level = mario_lm.sample(
22
  prompts=prompts,
23
+ num_steps=level_size,
24
+ temperature=temperature,
25
  use_tqdm=True
26
  )
27
  img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
28
+ return img
29
 
30
  with gr.Blocks() as demo:
31
+ gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt). Enter a text prompt or select parameters from below!")
32
+
33
+ text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
34
+
35
  level_image = gr.Image()
36
  btn = gr.Button("Generate level")
37
+
38
+ pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
39
+ enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
40
+ blocks = gr.Radio(["little", "some", "many"], label="blocks")
41
+ elevation = gr.Radio(["low", "high"], label="elevation")
42
+ temperature = gr.Number(value=2.0, label="temperature: Increase these for more stochastic, but lower quality, generations")
43
+ level_size = gr.Number(value=1399, precision=0, label="level_size")
44
+
45
+ btn.click(fn=update, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=level_image)
46
  gr.Examples(
47
+ examples=[
48
+ ["many", "many", "some", "high"],
49
+ ["no", "some", "many", "high", 2.0],
50
+ ["many", "many", "little", "low", 2.0],
51
+ ["no", "no", "many", "high", 2.4],
52
+ ],
53
+ inputs=[pipes, enemies, blocks, elevation],
54
  outputs=level_image,
55
  fn=update,
56
  cache_examples=True,