import gradio as gr
import torch
import uuid
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
# Load the MarioGPT language model once at import time; every handler below shares it.
mario_lm = MarioLM()
# Prefer the GPU, but fall back to CPU so the demo can still start on a host
# without CUDA (torch.device('cuda') alone fails at .to() on such machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mario_lm = mario_lm.to(device)
# Directory of tile images passed to convert_level_to_png when rendering levels.
TILE_DIR = "data/tiles"
app = FastAPI()
def make_html_file(generated_level):
    """Write an HTML page for ``generated_level`` into ``static/`` and return its file name.

    Args:
        generated_level: level tokens; rendered to text rows with
            ``view_level(generated_level, mario_lm.tokenizer)``.

    Returns:
        The bare file name ``demo-<uuid>.html``, relative to the ``static/`` directory.
    """
    # The strings returned by view_level, joined with newlines.
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    # NOTE(review): the template written below, as visible here, does not
    # interpolate ``level_text``; the page markup looks like it may be missing
    # from this copy of the file — confirm against the deployed version.
    # Create the output directory on demand so the very first write (which can
    # happen while examples are cached, before /static is mounted) does not fail.
    os.makedirs("static", exist_ok=True)
    # uuid4 is random; uuid1 would embed the host's MAC address and a timestamp
    # in a file name that is served under /static.
    filename = f"demo-{uuid.uuid4()}.html"
    with open(os.path.join("static", filename), 'w', encoding='utf-8') as f:
        f.write(f'''
Mario Game
''')
    return filename
def trim_level(level, height: int = 14):
    """Drop trailing tokens so the last dimension is a multiple of ``height``.

    Args:
        level: array/tensor of level tokens of any rank; only the last axis is trimmed.
        height: group size the last axis is trimmed to a multiple of. Defaults to 14
            (presumably the level height in tiles — confirm against the tokenizer).

    Returns:
        ``level`` itself when no trimming is needed, otherwise a slice of it.
    """
    length = level.shape[-1]
    mod = length % height
    if mod > 0:
        # Ellipsis indexing trims the last axis for 1-D, 2-D, ... inputs alike;
        # the previous ``level[:, :-mod]`` only worked for 2-D inputs.
        return level[..., : length - mod]
    return level
def reset_state(seed_state):
    """Empty ``seed_state`` in place, logging how many levels are discarded.

    The list object itself is kept (only its contents are removed), so any
    holder of a reference to it sees the reset.
    """
    print(f"Resetting state with {len(seed_state)} levels!")
    seed_state.clear()
def _generate_level(prompts, seed, level_size, temperature):
    """Sample levels from ``mario_lm`` for ``prompts`` and trim them with ``trim_level``.

    Args:
        prompts: list of prompt strings, one per sample.
        seed: optional seed tokens forwarded to ``mario_lm.sample`` (``None`` for none).
        level_size: number of sampling steps (``num_steps``).
        temperature: sampling temperature.
    """
    print(f"Using prompts: {prompts}")
    sample_kwargs = dict(
        prompts=prompts,
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
        seed=seed,
    )
    return trim_level(mario_lm.sample(**sample_kwargs))
def _make_gradio_html(level):
    """Write a page for ``level`` via ``make_html_file`` and return the snippet shown in the UI.

    Returns:
        The string rendered into the ``gr.HTML`` output (control instructions).

    NOTE(review): ``filename`` is not referenced by the template as visible here;
    it presumably fed a link/iframe to ``/static/<filename>`` whose markup appears
    to be missing from this copy — confirm against the deployed version.
    """
    filename = make_html_file(level)
    gradio_html = f'''
Press the arrow keys to move. Press a
to run, s
to jump and d
to shoot fireflowers
'''
    return gradio_html
def initialize_generate(pipes, enemies, blocks, elevation, temperature = 2.4, level_size = 1400):
    """Generate one level from the composed prompt and return ``[image, html]``.

    Used by the cached ``gr.Examples``; no seed context is passed to the model.
    """
    prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    level = _generate_level([prompt], None, level_size, temperature).squeeze().detach().cpu()
    image = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    html = _make_gradio_html(level)
    return [image, html]
def generate_choices(pipes, enemies, blocks, elevation, temperature = 2.4, level_size = 1400, prompt = "", seed_state = None):
    """Sample two candidate continuations of the current level.

    Args:
        pipes, enemies, blocks, elevation: radio selections used to compose the
            prompt when ``prompt`` is empty.
        temperature: sampling temperature.
        level_size: number of new tokens to keep from each sample.
        prompt: free-text prompt; empty (or ``None``) means "compose from the radios".
        seed_state: list of previously chosen level tensors; ``None`` means no
            history. It is only read here, never mutated.

    Returns:
        ``[level_choices, image_1, image_2]`` — the candidate token tensors followed
        by one rendered image per candidate.
    """
    NUM_SAMPLES = 2
    # A ``None`` default avoids the shared-mutable-default pitfall of ``= []``.
    if seed_state is None:
        seed_state = []
    if not prompt:
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    prompts = [prompt] * NUM_SAMPLES
    seed = None
    if seed_state:
        # Condition on the tail of the level built so far (48 * 14 tokens of context),
        # repeated once per sample.
        seed = torch.cat(seed_state).squeeze()[-48 * 14:].view(1, -1).repeat(NUM_SAMPLES, 1)
    generated_levels = _generate_level(prompts, seed, level_size, temperature).detach().cpu().squeeze()
    # Keep only the newly generated part of each sample (drops the seed prefix).
    level_choices = [generated_level[-level_size:] for generated_level in generated_levels]
    level_choice_images = [
        convert_level_to_png(level_choice, TILE_DIR, mario_lm.tokenizer)[0]
        for level_choice in level_choices
    ]
    # level choices + separate images
    return [level_choices, *level_choice_images]
def update_level_state(choice_id, level_choices, seed_state):
    """Append the picked candidate to ``seed_state`` and render the combined level.

    Args:
        choice_id: index of the chosen candidate (converted with ``int``).
        level_choices: candidates produced by ``generate_choices``.
        seed_state: list of chosen segments; extended in place and also returned.

    Returns:
        ``(image, html, seed_state, None, None, None, level_length)`` — the three
        ``None`` values clear the candidate state and both candidate images.
    """
    chosen = level_choices[int(choice_id)]
    seed_state.append(chosen)
    full_level = torch.cat(seed_state).squeeze()
    preview = convert_level_to_png(full_level, TILE_DIR, mario_lm.tokenizer)[0]
    html = _make_gradio_html(full_level)
    return preview, html, seed_state, None, None, None, full_level.shape[-1]
# Build the Gradio UI: prompt inputs, generation controls, the current level view,
# two sampled continuations to pick from, and the event wiring between them.
with gr.Blocks().queue() as demo:
    gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
''')
    # Prompt input: compose it from radio buttons, or type it freely. A non-empty
    # typed prompt takes precedence in generate_choices.
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
                elevation = gr.Radio(["low", "high"], label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Number(value=1400, precision=0, label="level_size")
    generate_btn = gr.Button("Generate Level")
    reset_btn = gr.Button("Reset Level")
    with gr.Row():
        # Left: the level assembled so far (HTML snippet + rendered image).
        with gr.Box():
            level_play = gr.HTML()
            level_image = gr.Image(label="Current Level")
        # Right: the two sampled candidates and the buttons to pick one.
        with gr.Box():
            with gr.Column():
                level_choice1_image = gr.Image(label="Sample Choice 1")
                level_choice1_btn = gr.Button("Sample Choice 1")
            with gr.Column():
                level_choice2_image = gr.Image(label="Sample Choice 2")
                level_choice2_btn = gr.Button("Sample Choice 2")
    current_level_size = gr.Number(0, visible=True, label="Current Level Size")
    # Hidden state: the list of chosen level segments, and the latest candidates
    # (None until generate_choices fills it; update_level_state resets it to None).
    seed_state = gr.State([])
    state_choices = gr.State(None)
    # Hidden constant inputs telling update_level_state which candidate was picked.
    image_choice_1_id = gr.Number(0, visible=False)
    image_choice_2_id = gr.Number(1, visible=False)
    # choice buttons
    level_choice1_btn.click(fn=update_level_state, inputs=[image_choice_1_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    level_choice2_btn.click(fn=update_level_state, inputs=[image_choice_2_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    # generate_btn
    generate_btn.click(fn=generate_choices, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt, seed_state], outputs=[state_choices, level_choice1_image, level_choice2_image])
    # reset btn: reset_state empties seed_state in place; no components are updated.
    reset_btn.click(fn=reset_state, inputs=[seed_state], outputs=[])
    # Example prompts rendered with initialize_generate; cache_examples=True
    # presumably runs them when the app is built — confirm for this Gradio version.
    # NOTE(review): each example row has 5 values for 6 inputs (level_size is
    # omitted); presumably initialize_generate's default of 1400 is used — confirm.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.4],
            ["no", "no", "many", "high", 2.8],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size],
        outputs=[level_image, level_play],
        fn=initialize_generate,
        cache_examples=True,
    )
# Port the server listens on; also used for the Gradio API URL so the two cannot drift.
PORT = 7860
# StaticFiles refuses to mount a missing directory, so make sure it exists.
# The per-level HTML pages written by make_html_file are served from here.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
# Serve the Gradio UI at the root path of the same FastAPI app.
app = gr.mount_gradio_app(app, demo, "/", gradio_api_url=f"http://localhost:{PORT}/")
# Start the server only when run as a script (e.g. ``python app.py``). This keeps
# ``app`` importable, e.g. by ``uvicorn module:app``, without starting a second server.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT)