"""Gradio demo for MarioGPT: generate a Super Mario level from a text prompt.

The prompt is either typed by the user or assembled from the radio-button
attributes (pipes / enemies / blocks / elevation). The sampled level is
rendered to an image using the tiles in ``TILE_DIR``.
"""
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

mario_lm = MarioLM()
# Fall back to CPU so the demo can still start on machines without a GPU
# (a hard-coded 'cuda' device fails there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)
TILE_DIR = "data/tiles"


def update(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt=""):
    """Sample one level from MarioLM and return it rendered as an image.

    Args:
        pipes: attribute word for pipes, used only when ``prompt`` is blank.
        enemies: attribute word for enemies, used only when ``prompt`` is blank.
        blocks: attribute word for blocks, used only when ``prompt`` is blank.
        elevation: attribute word for elevation, used only when ``prompt`` is blank.
        temperature: sampling temperature forwarded to ``mario_lm.sample``;
            ``None`` (e.g. a cleared UI field) falls back to 2.0.
        level_size: number of sampling steps forwarded to ``mario_lm.sample``;
            ``None`` falls back to 1399.
        prompt: free-text prompt. If empty or whitespace-only, a prompt is
            built from the four attribute words instead.

    Returns:
        The first element returned by ``convert_level_to_png`` for the
        sampled level.
    """
    if not prompt or not prompt.strip():
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    # gr.Number yields None when the field is cleared; use the defaults then.
    temperature = 2.0 if temperature is None else float(temperature)
    level_size = 1399 if level_size is None else int(level_size)

    print(f"Using prompt: {prompt}")
    generated_level = mario_lm.sample(
        prompts=[prompt],
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
    )
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    return img


with gr.Blocks() as demo:
    gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt). Enter a text prompt or select parameters from below!")
    text_prompt = gr.Textbox(
        value="",
        label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'",
    )
    level_image = gr.Image()
    btn = gr.Button("Generate level")

    # Default selections so an auto-built prompt never contains "None"
    # when the user clicks Generate without touching the radios.
    pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="pipes")
    enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="enemies")
    blocks = gr.Radio(["little", "some", "many"], value="some", label="blocks")
    elevation = gr.Radio(["low", "high"], value="low", label="elevation")
    temperature = gr.Number(value=2.0, label="temperature: Increase these for more stochastic, but lower quality, generations")
    level_size = gr.Number(value=1399, precision=0, label="level_size")

    btn.click(
        fn=update,
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt],
        outputs=level_image,
    )
    # Every example row supplies a temperature, and `temperature` is listed
    # as an input, so values like 2.4 actually reach `update`. Previously
    # the 5th column was dropped because only four inputs were listed.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.0],
            ["no", "no", "many", "high", 2.4],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature],
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()