# Hugging Face Spaces status at capture time: Runtime error
import os

import gradio as gr
import torch

from mario_gpt.dataset import MarioDataset
from mario_gpt.lm import MarioLM
from mario_gpt.prompter import Prompter
from mario_gpt.utils import convert_level_to_png, view_level
# Build the MarioGPT language model once, at module load, so every request
# handled by `update` reuses the same instance.
mario_lm = MarioLM()
# Prefer the GPU when one is available, but fall back to the CPU rather than
# crashing at startup on CPU-only hardware: moving a model to
# torch.device("cuda") raises when no CUDA device/driver is present.
# NOTE: sampling on CPU works but will be noticeably slower.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)
# Directory of tile sprites passed to convert_level_to_png. It is resolved
# relative to this file, not the process's working directory, so the app
# finds the tiles no matter where it is launched from.
TILE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "tiles")
def update(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a level from one text prompt and return its rendered image.

    Args:
        prompt: Free-text description forwarded to ``mario_lm.sample`` as a
            single-element prompt batch.
        progress: Not read in the body. It is declared with a
            ``gr.Progress(track_tqdm=True)`` default so Gradio can surface the
            tqdm bar that ``use_tqdm=True`` enables during sampling.

    Returns:
        The first element of ``convert_level_to_png``'s result for the sampled
        level (the value shown in the output image component).
    """
    sampled = mario_lm.sample(
        prompts=[prompt],
        num_steps=1399,
        temperature=2.0,
        use_tqdm=True,
    )
    # Drop size-1 dimensions (e.g. the batch of one) before rendering.
    rendered = convert_level_to_png(sampled.squeeze(), TILE_DIR, mario_lm.tokenizer)
    return rendered[0]
# Minimal UI: a text prompt in, a rendered level image out.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    btn.click(fn=update, inputs=prompt, outputs=level_image)

# Enable the request queue. `update` declares a gr.Progress parameter, and
# Gradio 3.x supports progress tracking only when queuing is on. The queue
# also keeps a long 1399-step generation from hitting the plain HTTP request
# timeout. Gradio 4+ queues by default, so this call is harmless there.
demo.queue()
demo.launch()