"""Playable Gradio/FastAPI demo for MarioGPT: generate, extend and play Mario levels."""
import os
import uuid

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

TILE_DIR = "data/tiles"
STATIC_DIR = "static"

# Levels are handled in chunks of 14 tokens along the last dimension
# (see trim_level); presumably one chunk per level column - TODO confirm.
LEVEL_HEIGHT = 14
# Number of trailing tokens of the current level fed back to the model as the
# seed/context when extending it (48 chunks of LEVEL_HEIGHT tokens).
CONTEXT_TOKENS = 48 * LEVEL_HEIGHT

mario_lm = MarioLM()
# Fall back to CPU when no GPU is available instead of crashing at start-up.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)

# Generated HTML pages are written here and served under /static. Create the
# directory up front: example caching may write pages while the UI is being
# built, and StaticFiles refuses to mount a missing directory.
os.makedirs(STATIC_DIR, exist_ok=True)

app = FastAPI()


def make_html_file(generated_level):
    """Write an HTML page for ``generated_level`` into STATIC_DIR.

    Returns the page's file name relative to STATIC_DIR (``demo-<uuid>.html``).
    """
    # Text form of the level, one string from view_level per line.
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    # uuid4 (random) rather than uuid1 so the publicly served file names do not
    # expose the host's MAC address/timestamp and are hard to guess.
    filename = f"demo-{uuid.uuid4()}.html"
    with open(os.path.join(STATIC_DIR, filename), "w", encoding="utf-8") as f:
        # NOTE(review): this page template appears to have lost its HTML markup;
        # `level_text` is never embedded, so the page cannot load the level.
        # Restore the original player template here.
        f.write(" Mario Game ")
    return filename


def trim_level(level):
    """Drop trailing tokens so the last dimension is a multiple of LEVEL_HEIGHT."""
    remainder = level.shape[-1] % LEVEL_HEIGHT
    if remainder > 0:
        return level[..., :-remainder]
    return level


def reset_state(seed_state):
    """Empty ``seed_state`` in place so the next generation starts a new level."""
    print(f"Resetting state with {len(seed_state)} levels!")
    seed_state.clear()


def _generate_level(prompts, seed, level_size, temperature):
    """Sample one level per prompt from the model and trim it to whole chunks."""
    print(f"Using prompts: {prompts}")
    generated_levels = mario_lm.sample(
        prompts=prompts,
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
        seed=seed,
    )
    return trim_level(generated_levels)


def _make_gradio_html(level):
    """Write the playable page for ``level`` and return HTML for the gr.HTML panel."""
    filename = make_html_file(level)
    # NOTE(review): the markup here appears stripped as well; `filename` (the page
    # served under /static) is never referenced, so the game is not embedded.
    gradio_html = (
        "\n\nPress the arrow keys to move. "
        "Press a to run, s to jump and d to shoot fireflowers\n\n"
    )
    return gradio_html


def initialize_generate(pipes, enemies, blocks, elevation, temperature=2.4, level_size=1400):
    """Generate a fresh level from the composed prompt.

    Returns ``[png_image, gradio_html]``.
    """
    prompts = [f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"]
    generated_levels = _generate_level(prompts, None, level_size, temperature)
    level = generated_levels.squeeze().detach().cpu()
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    return [img, _make_gradio_html(level)]


def generate_choices(pipes, enemies, blocks, elevation, temperature=2.4, level_size=1400,
                     prompt="", seed_state=None):
    """Sample two candidate continuations of the current level.

    If ``prompt`` is empty or blank, a prompt is composed from the radio
    selections. When ``seed_state`` holds previously chosen segments, the last
    CONTEXT_TOKENS tokens of them seed the model so the candidates continue the
    level.

    Returns ``[level_choices, image_1, image_2]``.
    """
    num_samples = 2
    if seed_state is None:
        seed_state = []
    if not prompt or not prompt.strip():
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    prompts = [prompt] * num_samples

    seed = None
    if seed_state:
        seed = (
            torch.cat(seed_state).squeeze()[-CONTEXT_TOKENS:]
            .view(1, -1)
            .repeat(num_samples, 1)
        )

    generated_levels = _generate_level(prompts, seed, level_size, temperature).detach().cpu().squeeze()
    # Keep only the newly generated tail of each sample (drops the seed context).
    level_choices = [generated_level[-level_size:] for generated_level in generated_levels]
    level_choice_images = [
        convert_level_to_png(choice, TILE_DIR, mario_lm.tokenizer)[0] for choice in level_choices
    ]
    # level choices + separate images
    return [level_choices, *level_choice_images]


def update_level_state(choice_id, level_choices, seed_state):
    """Append the chosen candidate to the level and re-render it.

    Raises gr.Error if no candidates have been generated yet.
    """
    if level_choices is None:
        raise gr.Error("No level choices yet - press 'Generate Level' first.")
    level_choice = level_choices[int(choice_id)]
    # append level choice to seed state
    seed_state.append(level_choice)
    # get new level from concatenation
    level = torch.cat(seed_state).squeeze()
    # final image and gradio html
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    gradio_html = _make_gradio_html(level)
    # img, gradio html, seed state, cleared choices, cleared choice images 1/2,
    # current level size
    return img, gradio_html, seed_state, None, None, None, level.shape[-1]


with gr.Blocks().queue() as demo:
    gr.Markdown(
        "### Playable demo for MarioGPT: Open-Ended Text2Level Generation "
        "through Large Language Models \n"
        "[[Github](https://github.com/shyamsn97/mario-gpt)], "
        "[[Paper](https://arxiv.org/abs/2302.05981)] "
    )
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
                elevation = gr.Radio(["low", "high"], label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(
                value="",
                label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'",
            )

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Number(
            value=2.0,
            label="temperature: Increase these for more diverse, but lower quality, generations",
        )
        level_size = gr.Number(value=1400, precision=0, label="level_size")

    generate_btn = gr.Button("Generate Level")
    reset_btn = gr.Button("Reset Level")

    with gr.Row():
        with gr.Box():
            level_play = gr.HTML()
            level_image = gr.Image(label="Current Level")
        with gr.Box():
            with gr.Column():
                level_choice1_image = gr.Image(label="Sample Choice 1")
                level_choice1_btn = gr.Button("Sample Choice 1")
            with gr.Column():
                level_choice2_image = gr.Image(label="Sample Choice 2")
                level_choice2_btn = gr.Button("Sample Choice 2")

    current_level_size = gr.Number(0, visible=True, label="Current Level Size")
    seed_state = gr.State([])
    state_choices = gr.State(None)
    image_choice_1_id = gr.Number(0, visible=False)
    image_choice_2_id = gr.Number(1, visible=False)

    choice_outputs = [
        level_image,
        level_play,
        seed_state,
        state_choices,
        level_choice1_image,
        level_choice2_image,
        current_level_size,
    ]
    # choice buttons
    level_choice1_btn.click(
        fn=update_level_state,
        inputs=[image_choice_1_id, state_choices, seed_state],
        outputs=choice_outputs,
    )
    level_choice2_btn.click(
        fn=update_level_state,
        inputs=[image_choice_2_id, state_choices, seed_state],
        outputs=choice_outputs,
    )

    # generate_btn
    generate_btn.click(
        fn=generate_choices,
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt, seed_state],
        outputs=[state_choices, level_choice1_image, level_choice2_image],
    )

    # reset btn
    reset_btn.click(fn=reset_state, inputs=[seed_state], outputs=[])

    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.4],
            ["no", "no", "many", "high", 2.8],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size],
        outputs=[level_image, level_play],
        fn=initialize_generate,
        cache_examples=True,
    )

app.mount("/static", StaticFiles(directory=STATIC_DIR, html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)