import gradio as gr import torch import uuid from mario_gpt.dataset import MarioDataset from mario_gpt.prompter import Prompter from mario_gpt.lm import MarioLM from mario_gpt.utils import view_level, convert_level_to_png from fastapi import FastAPI from fastapi.staticfiles import StaticFiles import os import uvicorn mario_lm = MarioLM() device = torch.device('cuda') mario_lm = mario_lm.to(device) TILE_DIR = "data/tiles" app = FastAPI() def make_html_file(generated_level): level_text = f"""{''' '''.join(view_level(generated_level,mario_lm.tokenizer))}""" unique_id = uuid.uuid1() with open(f"static/demo-{unique_id}.html", 'w', encoding='utf-8') as f: f.write(f''' Mario Game ''') return f"demo-{unique_id}.html" def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size = 1399, prompt = ""): if prompt == "": prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation" print(f"Using prompt: {prompt}") prompts = [prompt] generated_level = mario_lm.sample( prompts=prompts, num_steps=level_size, temperature=temperature, use_tqdm=True ) filename = make_html_file(generated_level) img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0] gradio_html = f'''

Press the arrow keys to move. Press a to run, s to jump and d to shoot fireflowers

''' return [img, gradio_html] with gr.Blocks() as demo: gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)] ''') with gr.Tabs(): with gr.TabItem("Compose prompt"): with gr.Row(): pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?") enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?") with gr.Row(): blocks = gr.Radio(["little", "some", "many"], label="How many blocks?") elevation = gr.Radio(["low", "high"], label="Elevation?") with gr.TabItem("Type prompt"): text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'") with gr.Accordion(label="Advanced settings", open=False): temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations") level_size = gr.Number(value=1399, precision=0, label="level_size") btn = gr.Button("Generate level") with gr.Row(): with gr.Box(): level_play = gr.HTML() level_image = gr.Image() btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play]) gr.Examples( examples=[ ["many", "many", "some", "high"], ["no", "some", "many", "high", 2.0], ["many", "many", "little", "low", 2.0], ["no", "no", "many", "high", 2.4], ], inputs=[pipes, enemies, blocks, elevation], outputs=[level_image, level_play], fn=generate, cache_examples=True, ) app.mount("/static", StaticFiles(directory="static", html=True), name="static") app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/") uvicorn.run(app, host="0.0.0.0", port=7860)