import gradio as gr
import torch

from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

# Load the pretrained MarioGPT model and move it to the GPU when one is available.
mario_lm = MarioLM()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)

TILE_DIR = "data/tiles"


def update(prompt, progress=gr.Progress(track_tqdm=True)):
    # Sample a level conditioned on the text prompt, then render it as a PNG
    # using the tile sprites in TILE_DIR.
    prompts = [prompt]
    generated_level = mario_lm.sample(
        prompts=prompts,
        num_steps=1399,
        temperature=2.0,
        use_tqdm=True,
    )
    img = convert_level_to_png(
        generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer
    )[0]
    return img


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    btn.click(fn=update, inputs=prompt, outputs=level_image)

demo.launch()
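
# Note (version-dependent): with older Gradio releases, the tqdm-tracked progress
# bar created by gr.Progress(track_tqdm=True) only updates when the request queue
# is enabled. If the bar stays empty, launching with the queue enabled, e.g.
# demo.queue().launch(), is a minimal change to try.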