Spaces:
Runtime error
Runtime error
Commit
•
850b0e4
1
Parent(s):
b995aaa
MarioGPT first attempt
Browse files- .gitignore +1 -0
- LICENSE +21 -0
- Makefile +28 -0
- app.py +32 -0
- data/tiles/N.png +0 -0
- data/tiles/Y.png +0 -0
- data/tiles/cannon_bottom.png +0 -0
- data/tiles/cannon_top.png +0 -0
- data/tiles/flying_koopa.png +0 -0
- data/tiles/ki-background.png +0 -0
- data/tiles/ki-door.png +0 -0
- data/tiles/ki-hazard.png +0 -0
- data/tiles/ki-moving-platform.png +0 -0
- data/tiles/ki-passable.png +0 -0
- data/tiles/ki-path.png +0 -0
- data/tiles/ki-unpassable.png +0 -0
- data/tiles/mm-CMM.png +0 -0
- data/tiles/mm-DMM.png +0 -0
- data/tiles/mm-HMM.png +0 -0
- data/tiles/mm-LMM.png +0 -0
- data/tiles/mm-MMM.png +0 -0
- data/tiles/mm-TMM.png +0 -0
- data/tiles/mma_tiles.zip +3 -0
- data/tiles/plant.png +0 -0
- data/tiles/smb-background.png +0 -0
- data/tiles/smb-breakable.png +0 -0
- data/tiles/smb-coin.png +0 -0
- data/tiles/smb-enemy.png +0 -0
- data/tiles/smb-path.png +0 -0
- data/tiles/smb-question.png +0 -0
- data/tiles/smb-tube-lower-left.png +0 -0
- data/tiles/smb-tube-lower-right.png +0 -0
- data/tiles/smb-tube-top-left.png +0 -0
- data/tiles/smb-tube-top-right.png +0 -0
- data/tiles/smb-unpassable.png +0 -0
- data/tiles/smb_enemies_sheet.png +0 -0
- data/tiles/tile004 (1).png +0 -0
- data/tiles/tile004 (2).png +0 -0
- data/tiles/tile004.png +0 -0
- mario_gpt/__init__.py +0 -0
- mario_gpt/dataset.py +152 -0
- mario_gpt/level.py +0 -0
- mario_gpt/lm.py +150 -0
- mario_gpt/prompter.py +175 -0
- mario_gpt/utils.py +99 -0
- notebooks/Sampling.ipynb +349 -0
- setup.py +43 -0
- static/architecture.png +0 -0
- static/prompt-samples.png +0 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.DS_Store
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
|
3 |
+
Copyright (c) 2023 Shyam Sudhakaran
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Makefile
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
|
2 |
+
|
3 |
+
clean-build: ## remove build artifacts
|
4 |
+
rm -fr build/
|
5 |
+
rm -fr dist/
|
6 |
+
rm -fr .eggs/
|
7 |
+
find . -name '*.egg-info' -exec rm -fr {} +
|
8 |
+
find . -name '*.egg' -exec rm -f {} +
|
9 |
+
|
10 |
+
clean-pyc: ## remove Python file artifacts
|
11 |
+
find . -name '*.pyc' -exec rm -f {} +
|
12 |
+
find . -name '*.pyo' -exec rm -f {} +
|
13 |
+
find . -name '*~' -exec rm -f {} +
|
14 |
+
find . -name '__pycache__' -exec rm -fr {} +
|
15 |
+
|
16 |
+
clean-test: ## remove test and coverage artifacts
|
17 |
+
rm -fr .tox/
|
18 |
+
rm -f .coverage
|
19 |
+
rm -fr coverage/
|
20 |
+
rm -fr .pytest_cache
|
21 |
+
|
22 |
+
lint: ## check style with flake8
|
23 |
+
isort --profile black mario_gpt
|
24 |
+
black mario_gpt
|
25 |
+
flake8 mario_gpt
|
26 |
+
|
27 |
+
install: clean lint
|
28 |
+
python setup.py install
|
app.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from mario_gpt.dataset import MarioDataset
|
4 |
+
from mario_gpt.prompter import Prompter
|
5 |
+
from mario_gpt.lm import MarioLM
|
6 |
+
from mario_gpt.utils import view_level, convert_level_to_png
|
7 |
+
|
8 |
+
mario_lm = MarioLM()
|
9 |
+
|
10 |
+
device = torch.device('cuda')
|
11 |
+
mario_lm = mario_lm.to(device)
|
12 |
+
TILE_DIR = "data/tiles"
|
13 |
+
|
14 |
+
def update(prompt, progress=gr.Progress(track_tqdm=True)):
|
15 |
+
prompts = [prompt]
|
16 |
+
generated_level = mario_lm.sample(
|
17 |
+
prompts=prompts,
|
18 |
+
num_steps=1399,
|
19 |
+
temperature=2.0,
|
20 |
+
use_tqdm=True
|
21 |
+
)
|
22 |
+
img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
|
23 |
+
return img
|
24 |
+
|
25 |
+
with gr.Blocks() as demo:
|
26 |
+
with gr.Row():
|
27 |
+
prompt = gr.Textbox(label="Enter your MarioGPT prompt")
|
28 |
+
level_image = gr.Image()
|
29 |
+
btn = gr.Button("Generate level")
|
30 |
+
btn.click(fn=update, inputs=prompt, outputs=level_image)
|
31 |
+
pass
|
32 |
+
demo.launch()
|
data/tiles/N.png
ADDED
data/tiles/Y.png
ADDED
data/tiles/cannon_bottom.png
ADDED
data/tiles/cannon_top.png
ADDED
data/tiles/flying_koopa.png
ADDED
data/tiles/ki-background.png
ADDED
data/tiles/ki-door.png
ADDED
data/tiles/ki-hazard.png
ADDED
data/tiles/ki-moving-platform.png
ADDED
data/tiles/ki-passable.png
ADDED
data/tiles/ki-path.png
ADDED
data/tiles/ki-unpassable.png
ADDED
data/tiles/mm-CMM.png
ADDED
data/tiles/mm-DMM.png
ADDED
data/tiles/mm-HMM.png
ADDED
data/tiles/mm-LMM.png
ADDED
data/tiles/mm-MMM.png
ADDED
data/tiles/mm-TMM.png
ADDED
data/tiles/mma_tiles.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6d58bb3228bcd3c653c4a58b69044588ffd6e5e4c946a860497a39d84eb60b8
|
3 |
+
size 6586
|
data/tiles/plant.png
ADDED
data/tiles/smb-background.png
ADDED
data/tiles/smb-breakable.png
ADDED
data/tiles/smb-coin.png
ADDED
data/tiles/smb-enemy.png
ADDED
data/tiles/smb-path.png
ADDED
data/tiles/smb-question.png
ADDED
data/tiles/smb-tube-lower-left.png
ADDED
data/tiles/smb-tube-lower-right.png
ADDED
data/tiles/smb-tube-top-left.png
ADDED
data/tiles/smb-tube-top-right.png
ADDED
data/tiles/smb-unpassable.png
ADDED
data/tiles/smb_enemies_sheet.png
ADDED
data/tiles/tile004 (1).png
ADDED
data/tiles/tile004 (2).png
ADDED
data/tiles/tile004.png
ADDED
mario_gpt/__init__.py
ADDED
File without changes
|
mario_gpt/dataset.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
9 |
+
|
10 |
+
from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS
|
11 |
+
|
12 |
+
DEFAULT_MODEL = "distilgpt2"
|
13 |
+
|
14 |
+
|
15 |
+
def split_given_size(a, size):
|
16 |
+
return np.split(a, np.arange(size, len(a), size))
|
17 |
+
|
18 |
+
|
19 |
+
def flip_and_transpose(arr: np.array, flip_first: bool = False):
|
20 |
+
if arr.shape[-1] > 1:
|
21 |
+
if flip_first:
|
22 |
+
return np.flip(arr, -1).transpose()
|
23 |
+
return np.flip(arr.transpose(), -1)
|
24 |
+
return arr
|
25 |
+
|
26 |
+
|
27 |
+
def join_list_of_list(str_lists):
|
28 |
+
return ["".join(s) for s in str_lists]
|
29 |
+
|
30 |
+
|
31 |
+
def characterize(str_lists):
|
32 |
+
return [list(s) for s in str_lists]
|
33 |
+
|
34 |
+
|
35 |
+
class MarioDataset(Dataset):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
tokenizer: Optional[PreTrainedTokenizer] = None,
|
39 |
+
level_string: Optional[str] = None,
|
40 |
+
context_len: int = 700,
|
41 |
+
height: int = 14,
|
42 |
+
remove_start_end_tokens: bool = False,
|
43 |
+
sample_all_indices: bool = False,
|
44 |
+
):
|
45 |
+
if level_string is None:
|
46 |
+
print(
|
47 |
+
"No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
|
48 |
+
)
|
49 |
+
level_string = FULL_LEVEL_STR_WITH_PATHS
|
50 |
+
elif ".txt" in level_string:
|
51 |
+
with open(level_string, "r") as file:
|
52 |
+
level_string = file.read()
|
53 |
+
|
54 |
+
self.character_set = set(level_string)
|
55 |
+
if "\n" in self.character_set:
|
56 |
+
self.character_set.remove("\n")
|
57 |
+
self.vocab_size = len(self.character_set)
|
58 |
+
self.sample_all_indices = sample_all_indices
|
59 |
+
|
60 |
+
def get_training_corpus():
|
61 |
+
yield list(level_string)
|
62 |
+
|
63 |
+
if tokenizer is None:
|
64 |
+
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
|
65 |
+
|
66 |
+
self.tokenizer = tokenizer
|
67 |
+
if getattr(tokenizer, "train_new_from_iterator", None) is not None:
|
68 |
+
self.tokenizer = tokenizer.train_new_from_iterator(
|
69 |
+
get_training_corpus(), 52000
|
70 |
+
)
|
71 |
+
elif getattr(tokenizer, "train_from_iterator", None) is not None:
|
72 |
+
self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
|
73 |
+
self.tokenizer = self.tokenizer.train_new_from_iterator(
|
74 |
+
get_training_corpus(), self.vocab_size
|
75 |
+
)
|
76 |
+
self.context_len = context_len
|
77 |
+
self.height = height
|
78 |
+
|
79 |
+
x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
|
80 |
+
self.input_ids = x["input_ids"].squeeze()
|
81 |
+
self.attention_masks = x["attention_mask"].squeeze()
|
82 |
+
if remove_start_end_tokens:
|
83 |
+
self.input_ids = self.input_ids[1:-1]
|
84 |
+
self.attention_masks = self.attention_masks[1:-1]
|
85 |
+
|
86 |
+
self.indices = self.generate_indices()
|
87 |
+
|
88 |
+
self.unique_tokens, self.unique_counts = self.input_ids.unique(
|
89 |
+
return_counts=True
|
90 |
+
)
|
91 |
+
self.weighted_unique_counts = (
|
92 |
+
1.0 / self.unique_counts / torch.sum(self.unique_counts)
|
93 |
+
)
|
94 |
+
|
95 |
+
self.token_dict = {}
|
96 |
+
string_tokens = list(self.tokenizer.decode(self.unique_tokens))
|
97 |
+
for int_token, string_token in zip(self.unique_tokens, string_tokens):
|
98 |
+
self.token_dict[string_token] = int_token
|
99 |
+
|
100 |
+
def convert_level_to_tensor(self, level: List[str]):
|
101 |
+
str_arr = flip_and_transpose(np.array(characterize(level)))
|
102 |
+
str_arr = "".join(join_list_of_list(str_arr))
|
103 |
+
|
104 |
+
x = self.tokenizer(str_arr, return_tensors="pt")
|
105 |
+
return x, str_arr
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return self.indices.shape[0]
|
109 |
+
|
110 |
+
def __getitem__(self, idx):
|
111 |
+
indices = self.indices[idx]
|
112 |
+
return self.input_ids[indices], self.attention_masks[indices]
|
113 |
+
|
114 |
+
def generate_indices(self):
|
115 |
+
out = []
|
116 |
+
for idx in range(self.input_ids.shape[0] - self.context_len):
|
117 |
+
if idx % self.height == 0 or self.sample_all_indices:
|
118 |
+
arange = torch.arange(idx, idx + self.context_len)
|
119 |
+
out.append(arange)
|
120 |
+
return torch.stack(out)
|
121 |
+
|
122 |
+
def sample_indices(self, batch_size):
|
123 |
+
out = []
|
124 |
+
for _ in range(batch_size):
|
125 |
+
start_idx = np.random.randint(0, self.__len__() - self.context_len)
|
126 |
+
indices = torch.arange(start_idx, start_idx + self.context_len)
|
127 |
+
out.append(indices)
|
128 |
+
return torch.stack(out)
|
129 |
+
|
130 |
+
def __str__(self):
|
131 |
+
str_list = characterize(self.tokenizer.batch_decode(self.x["input_ids"]))
|
132 |
+
string = "\n".join(
|
133 |
+
join_list_of_list(flip_and_transpose(np.array(str_list), True))
|
134 |
+
)
|
135 |
+
return string
|
136 |
+
|
137 |
+
def generate_mask(self, mask_len: int, batch_size: int = 1):
|
138 |
+
mask_token = self.tokenizer("<mask>").input_ids[1]
|
139 |
+
ones = torch.ones((batch_size, mask_len))
|
140 |
+
return ones * mask_token
|
141 |
+
|
142 |
+
def apply_mask(self, level, masked_indices, mask=None):
|
143 |
+
if len(level.shape) == 1:
|
144 |
+
level = level.unsqueeze(0)
|
145 |
+
batch_size = level.shape[0]
|
146 |
+
mask_len = masked_indices.shape[-1]
|
147 |
+
if mask is None:
|
148 |
+
mask = self.generate_mask(mask_len, batch_size)
|
149 |
+
mask = mask.long().to(level.device)
|
150 |
+
masked_level = level * torch.ones_like(level).to(level.device)
|
151 |
+
masked_level[:, masked_indices] = mask
|
152 |
+
return masked_level
|
mario_gpt/level.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mario_gpt/lm.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
from transformers import (
|
7 |
+
AutoModelWithLMHead,
|
8 |
+
AutoTokenizer,
|
9 |
+
GPT2Model,
|
10 |
+
GPT2Tokenizer,
|
11 |
+
LogitsProcessorList,
|
12 |
+
PreTrainedModel,
|
13 |
+
PreTrainedTokenizer,
|
14 |
+
TemperatureLogitsWarper,
|
15 |
+
TopKLogitsWarper,
|
16 |
+
)
|
17 |
+
|
18 |
+
from mario_gpt.prompter import Prompter
|
19 |
+
|
20 |
+
PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
|
21 |
+
|
22 |
+
|
23 |
+
class MarioLM:
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
lm: Optional[PreTrainedModel] = None,
|
27 |
+
tokenizer: Optional[PreTrainedTokenizer] = None,
|
28 |
+
context_len: int = 700,
|
29 |
+
prompter: Optional[Prompter] = None,
|
30 |
+
):
|
31 |
+
self.context_len = context_len
|
32 |
+
self.lm = lm
|
33 |
+
|
34 |
+
if lm is None:
|
35 |
+
self.lm = self.load_pretrained_lm()
|
36 |
+
|
37 |
+
self.tokenizer = tokenizer
|
38 |
+
if tokenizer is None:
|
39 |
+
self.tokenizer = self.load_pretrained_tokenizer()
|
40 |
+
|
41 |
+
self.prompter = prompter
|
42 |
+
if prompter is None:
|
43 |
+
self.prompter = Prompter(self.tokenizer)
|
44 |
+
|
45 |
+
@property
|
46 |
+
def device(self):
|
47 |
+
return self.lm.device
|
48 |
+
|
49 |
+
def to(self, device: torch.device):
|
50 |
+
self.lm = self.lm.to(device)
|
51 |
+
return self
|
52 |
+
|
53 |
+
def load_pretrained_lm(self) -> GPT2Model:
|
54 |
+
print(f"Using {PRETRAINED_MODEL_PATH} model")
|
55 |
+
return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)
|
56 |
+
|
57 |
+
def load_pretrained_tokenizer(self) -> GPT2Tokenizer:
|
58 |
+
print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
|
59 |
+
return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
|
60 |
+
|
61 |
+
def sample_step(
|
62 |
+
self,
|
63 |
+
seed: torch.Tensor,
|
64 |
+
encoder_hidden_states: torch.Tensor,
|
65 |
+
temperature: float = 2.0,
|
66 |
+
):
|
67 |
+
lm = self.lm
|
68 |
+
logits_processor = LogitsProcessorList()
|
69 |
+
logits_warper = LogitsProcessorList(
|
70 |
+
[
|
71 |
+
TopKLogitsWarper(16), # number of characters
|
72 |
+
TemperatureLogitsWarper(temperature),
|
73 |
+
]
|
74 |
+
)
|
75 |
+
with torch.no_grad():
|
76 |
+
attention_mask = torch.ones_like(seed).to(seed.device)
|
77 |
+
input_ids = seed
|
78 |
+
out = lm(
|
79 |
+
input_ids=input_ids,
|
80 |
+
attention_mask=attention_mask,
|
81 |
+
encoder_hidden_states=encoder_hidden_states,
|
82 |
+
token_type_ids=None,
|
83 |
+
)
|
84 |
+
logits = out.logits.detach()
|
85 |
+
if len(logits.shape) == 2:
|
86 |
+
logits = logits.view(1, 1, -1)
|
87 |
+
next_token_logits = logits[:, -1, :]
|
88 |
+
|
89 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
90 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
91 |
+
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
92 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
93 |
+
return next_tokens, encoder_hidden_states
|
94 |
+
|
95 |
+
def sample(
|
96 |
+
self,
|
97 |
+
seed: Optional[torch.Tensor] = None,
|
98 |
+
prompts: Optional[List[str]] = None,
|
99 |
+
num_steps: int = 1,
|
100 |
+
temperature: float = 2.0,
|
101 |
+
encoder_hidden_states: torch.Tensor = None,
|
102 |
+
use_tqdm: bool = False,
|
103 |
+
):
|
104 |
+
context_len = self.context_len - 28
|
105 |
+
self.lm.eval()
|
106 |
+
with torch.no_grad():
|
107 |
+
if seed is None:
|
108 |
+
seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
|
109 |
+
out = seed.to(self.device)
|
110 |
+
if encoder_hidden_states is None:
|
111 |
+
if prompts is not None:
|
112 |
+
encoder_hidden_states = torch.stack(
|
113 |
+
[self.prompter.output_hidden(prompt) for prompt in prompts]
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
encoder_hidden_states = torch.stack(
|
117 |
+
[
|
118 |
+
self.prompter(sample_prompt=True)[1]
|
119 |
+
for _ in range(seed.shape[0])
|
120 |
+
]
|
121 |
+
)
|
122 |
+
encoder_hidden_states = encoder_hidden_states.to(
|
123 |
+
self.device
|
124 |
+
) # b x 1 x hidden_dim
|
125 |
+
encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
|
126 |
+
if not use_tqdm:
|
127 |
+
bar = np.arange(num_steps)
|
128 |
+
else:
|
129 |
+
bar = tqdm(np.arange(num_steps))
|
130 |
+
with torch.no_grad():
|
131 |
+
for i in bar:
|
132 |
+
inp = out * 1
|
133 |
+
if len(out.shape) > 0 and out.shape[-1] > context_len:
|
134 |
+
diff = inp.shape[-1] % 14 # height of mario level
|
135 |
+
ctx = context_len + diff
|
136 |
+
inp = inp[:, -ctx:] * 1
|
137 |
+
next_tokens, encoder_hidden_states = self.sample_step(
|
138 |
+
inp,
|
139 |
+
encoder_hidden_states=encoder_hidden_states,
|
140 |
+
temperature=temperature,
|
141 |
+
)
|
142 |
+
out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
|
143 |
+
if use_tqdm:
|
144 |
+
bar.set_description(
|
145 |
+
f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
|
146 |
+
)
|
147 |
+
if use_tqdm:
|
148 |
+
bar.close()
|
149 |
+
self.lm.train()
|
150 |
+
return out
|
mario_gpt/prompter.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import random
|
4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from scipy import stats
|
9 |
+
from transformers import pipeline
|
10 |
+
|
11 |
+
from mario_gpt.dataset import MarioDataset
|
12 |
+
from mario_gpt.utils import view_level
|
13 |
+
|
14 |
+
STATISTICS = {
|
15 |
+
"enemy": np.array([1.0, 3.0, 7.0]),
|
16 |
+
"pipe": np.array([0.0, 2.0, 5.0]),
|
17 |
+
"block": np.array([50.0, 75.0, 176.0]),
|
18 |
+
}
|
19 |
+
|
20 |
+
FEATURE_EXTRACTION_MODEL = "facebook/bart-base"
|
21 |
+
|
22 |
+
|
23 |
+
class Prompter:
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
level_tokenizer,
|
27 |
+
prompter_model: str = FEATURE_EXTRACTION_MODEL,
|
28 |
+
use_raw_counts: bool = False,
|
29 |
+
statistics: Optional[Dict[str, Any]] = None,
|
30 |
+
):
|
31 |
+
self.prompter_model = prompter_model
|
32 |
+
self.feature_extraction = pipeline(
|
33 |
+
"feature-extraction",
|
34 |
+
model=prompter_model,
|
35 |
+
tokenizer=prompter_model,
|
36 |
+
framework="pt",
|
37 |
+
)
|
38 |
+
|
39 |
+
self.level_tokenizer = level_tokenizer
|
40 |
+
|
41 |
+
self.use_raw_counts = use_raw_counts
|
42 |
+
self.statistics = statistics
|
43 |
+
if statistics is None:
|
44 |
+
self.statistics = STATISTICS
|
45 |
+
|
46 |
+
@property
|
47 |
+
def pipe_thresholds(self) -> Tuple[List[int], List[str]]:
|
48 |
+
thresholds = self.statistics["pipe"]
|
49 |
+
keywords = ["no", "little", "some", "many"]
|
50 |
+
return thresholds, keywords
|
51 |
+
|
52 |
+
@property
|
53 |
+
def enemy_thresholds(self) -> Tuple[List[int], List[str]]:
|
54 |
+
thresholds = self.statistics["enemy"]
|
55 |
+
keywords = ["no", "little", "some", "many"]
|
56 |
+
return thresholds, keywords
|
57 |
+
|
58 |
+
@property
|
59 |
+
def block_thresholds(self) -> Tuple[List[int], List[str]]:
|
60 |
+
thresholds = self.statistics["block"]
|
61 |
+
keywords = ["little", "little", "some", "many"]
|
62 |
+
return thresholds, keywords
|
63 |
+
|
64 |
+
def count_pipes(self, flattened_level: str) -> int:
|
65 |
+
return flattened_level.count("<>")
|
66 |
+
|
67 |
+
def count_enemies(self, flattened_level: str) -> int:
|
68 |
+
return flattened_level.count("E") + flattened_level.count("B")
|
69 |
+
|
70 |
+
def count_blocks(self, flattened_level: str) -> int:
|
71 |
+
return np.sum([flattened_level.count(char) for char in ["X", "S", "?", "Q"]])
|
72 |
+
|
73 |
+
def _flatten_level(self, string_level: List[str]) -> str:
|
74 |
+
return "".join(string_level)
|
75 |
+
|
76 |
+
def pipe_prompt(self, flattened_level: str, level: str) -> str:
|
77 |
+
count = self.count_pipes(flattened_level)
|
78 |
+
keyword = f"{count}"
|
79 |
+
if not self.use_raw_counts:
|
80 |
+
thresholds, keywords = self.pipe_thresholds
|
81 |
+
threshold = np.digitize(count, thresholds, right=True)
|
82 |
+
keyword = keywords[threshold]
|
83 |
+
return f"{keyword} pipes", keyword
|
84 |
+
|
85 |
+
def enemy_prompt(self, flattened_level: str, level: str) -> str:
|
86 |
+
count = self.count_enemies(flattened_level)
|
87 |
+
keyword = f"{count}"
|
88 |
+
if not self.use_raw_counts:
|
89 |
+
thresholds, keywords = self.enemy_thresholds
|
90 |
+
threshold = np.digitize(count, thresholds, right=True)
|
91 |
+
keyword = keywords[threshold]
|
92 |
+
return f"{keyword} enemies", keyword
|
93 |
+
|
94 |
+
def block_prompt(self, flattened_level: str, level: str) -> str:
|
95 |
+
count = self.count_blocks(flattened_level)
|
96 |
+
keyword = f"{count}"
|
97 |
+
if not self.use_raw_counts:
|
98 |
+
thresholds, keywords = self.block_thresholds
|
99 |
+
threshold = np.digitize(count, thresholds, right=True)
|
100 |
+
keyword = keywords[threshold]
|
101 |
+
return f"{keyword} blocks", keyword
|
102 |
+
|
103 |
+
def elevation_prompt(self, flattened_level: str, level: str):
|
104 |
+
top_levels = level[:6] # elevation 8 and up
|
105 |
+
for t in top_levels:
|
106 |
+
if "X" in t or "<" in t or ">" in t:
|
107 |
+
return "high elevation", "high"
|
108 |
+
return "low elevation", "low"
|
109 |
+
|
110 |
+
def output_hidden(self, prompt: str, device: torch.device = torch.device("cpu")):
|
111 |
+
# Reducing along the first dimension to get a 768 dimensional array
|
112 |
+
return (
|
113 |
+
self.feature_extraction(prompt, return_tensors="pt")[0]
|
114 |
+
.mean(0)
|
115 |
+
.to(device)
|
116 |
+
.view(1, -1)
|
117 |
+
)
|
118 |
+
|
119 |
+
def dataset_statistics(self, dataset: MarioDataset):
|
120 |
+
enemy_counts = []
|
121 |
+
pipe_counts = []
|
122 |
+
block_counts = []
|
123 |
+
for i in range(len(dataset)):
|
124 |
+
level, _ = dataset[i]
|
125 |
+
str_level = self._flatten_level(view_level(level, dataset.tokenizer))
|
126 |
+
|
127 |
+
enemy_count = self.count_enemies(str_level)
|
128 |
+
pipe_count = self.count_pipes(str_level)
|
129 |
+
block_count = self.count_blocks(str_level)
|
130 |
+
|
131 |
+
enemy_counts.append(enemy_count)
|
132 |
+
pipe_counts.append(pipe_count)
|
133 |
+
block_counts.append(block_count)
|
134 |
+
d = {"enemy": {}, "pipe": {}, "block": {}}
|
135 |
+
|
136 |
+
d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
|
137 |
+
d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
|
138 |
+
d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
|
139 |
+
return d
|
140 |
+
|
141 |
+
def __call__(
|
142 |
+
self, level: torch.Tensor = None, sample_prompt: bool = False
|
143 |
+
) -> Union[str, torch.Tensor]:
|
144 |
+
device: torch.device = torch.device("cpu")
|
145 |
+
if not sample_prompt:
|
146 |
+
if level is None:
|
147 |
+
raise ValueError("Level must be provided if sample_prompt is not true!")
|
148 |
+
str_level = view_level(level, self.level_tokenizer)
|
149 |
+
flattened_level = self._flatten_level(str_level)
|
150 |
+
|
151 |
+
pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
|
152 |
+
enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
|
153 |
+
block_prompt, _ = self.block_prompt(flattened_level, str_level)
|
154 |
+
elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
|
155 |
+
device = level.device
|
156 |
+
else:
|
157 |
+
str_level = None
|
158 |
+
pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
|
159 |
+
enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
|
160 |
+
block_prompt = (
|
161 |
+
random.choice(["little", "little", "some", "many"]) + " blocks"
|
162 |
+
) # levels always have blocks
|
163 |
+
elevation_prompt = (
|
164 |
+
random.choice(["low", "high"]) + " elevation"
|
165 |
+
) # levels always have blocks
|
166 |
+
|
167 |
+
prompt_dict = {
|
168 |
+
"pipe": pipe_prompt,
|
169 |
+
"enemy": enemy_prompt,
|
170 |
+
"block": block_prompt,
|
171 |
+
"elevation_prompt": elevation_prompt,
|
172 |
+
}
|
173 |
+
prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
|
174 |
+
hidden = self.output_hidden(prompt, device=device)
|
175 |
+
return prompt, hidden, prompt_dict, str_level
|
mario_gpt/utils.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def characterize(str_lists):
|
9 |
+
return [list(s[::-1]) for s in str_lists]
|
10 |
+
|
11 |
+
|
12 |
+
def join_list_of_list(str_lists):
|
13 |
+
return ["".join(s) for s in str_lists]
|
14 |
+
|
15 |
+
|
16 |
+
def view_level(level_tokens, tokenizer):
|
17 |
+
str_list = [
|
18 |
+
s.replace("<mask>", "Y")
|
19 |
+
for s in tokenizer.batch_decode(level_tokens.detach().cpu().view(-1, 14))
|
20 |
+
]
|
21 |
+
return join_list_of_list(np.array(characterize(str_list)).T)
|
22 |
+
|
23 |
+
|
24 |
+
def is_flying_enemy(array, row, col):
|
25 |
+
num_rows = array.shape[0]
|
26 |
+
if row == num_rows - 1:
|
27 |
+
return False
|
28 |
+
below = array[row + 1][col]
|
29 |
+
return below == "-"
|
30 |
+
|
31 |
+
|
32 |
+
def char_array_to_image(array, chars2pngs):
|
33 |
+
"""
|
34 |
+
Convert a 16-by-16 array of integers into a PIL.Image object
|
35 |
+
param: array: a 16-by-16 array of integers
|
36 |
+
"""
|
37 |
+
image = Image.new("RGB", (array.shape[1] * 16, array.shape[0] * 16))
|
38 |
+
for row in range(array.shape[0]):
|
39 |
+
for col, char in enumerate(array[row]):
|
40 |
+
value = chars2pngs["-"]
|
41 |
+
# if char == "E":
|
42 |
+
# if is_flying_enemy(array, row, col):
|
43 |
+
# char = "F"
|
44 |
+
if char in chars2pngs:
|
45 |
+
value = chars2pngs[char]
|
46 |
+
else:
|
47 |
+
print(f"REPLACING {value}", (col, row))
|
48 |
+
|
49 |
+
image.paste(value, (col * 16, row * 16))
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
def convert_level_to_png(
|
54 |
+
level: Union[str, torch.Tensor], tiles_dir: str, tokenizer=None
|
55 |
+
):
|
56 |
+
if isinstance(level, torch.Tensor):
|
57 |
+
level = view_level(level, tokenizer)
|
58 |
+
chars2pngs = {
|
59 |
+
"-": Image.open(f"{tiles_dir}/smb-background.png"),
|
60 |
+
"X": Image.open(f"{tiles_dir}/smb-unpassable.png"),
|
61 |
+
"S": Image.open(f"{tiles_dir}/smb-breakable.png"),
|
62 |
+
"?": Image.open(f"{tiles_dir}/smb-question.png"),
|
63 |
+
"Q": Image.open(f"{tiles_dir}/smb-question.png"),
|
64 |
+
"o": Image.open(f"{tiles_dir}/smb-coin.png"),
|
65 |
+
"E": Image.open(f"{tiles_dir}/smb-enemy.png"),
|
66 |
+
"<": Image.open(f"{tiles_dir}/smb-tube-top-left.png"),
|
67 |
+
">": Image.open(f"{tiles_dir}/smb-tube-top-right.png"),
|
68 |
+
"[": Image.open(f"{tiles_dir}/smb-tube-lower-left.png"),
|
69 |
+
"]": Image.open(f"{tiles_dir}/smb-tube-lower-right.png"),
|
70 |
+
"x": Image.open(f"{tiles_dir}/smb-path.png"), # self-created
|
71 |
+
"Y": Image.open(f"{tiles_dir}/Y.png"), # self-created
|
72 |
+
"N": Image.open(f"{tiles_dir}/N.png"), # self-created
|
73 |
+
"B": Image.open(f"{tiles_dir}/cannon_top.png"),
|
74 |
+
"b": Image.open(f"{tiles_dir}/cannon_bottom.png"),
|
75 |
+
"F": Image.open(f"{tiles_dir}/flying_koopa.png"),
|
76 |
+
}
|
77 |
+
levels = [list(s) for s in level]
|
78 |
+
arr = np.array(levels)
|
79 |
+
return char_array_to_image(arr, chars2pngs), arr, level
|
80 |
+
|
81 |
+
|
82 |
+
TOKENS = [
|
83 |
+
"-",
|
84 |
+
"X",
|
85 |
+
"S",
|
86 |
+
"?",
|
87 |
+
"Q",
|
88 |
+
"o",
|
89 |
+
"E",
|
90 |
+
"<",
|
91 |
+
">",
|
92 |
+
"[",
|
93 |
+
"]",
|
94 |
+
"x",
|
95 |
+
"Y",
|
96 |
+
"N",
|
97 |
+
"B",
|
98 |
+
"b",
|
99 |
+
]
|
notebooks/Sampling.ipynb
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "df85b023-cdb5-498e-8373-0fd5b7c31853",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"## Load Stuff"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": 1,
|
14 |
+
"id": "895fc851-817b-4c23-baf4-72cf73238781",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"import torch\n",
|
19 |
+
"from mario_gpt.dataset import MarioDataset\n",
|
20 |
+
"from mario_gpt.prompter import Prompter\n",
|
21 |
+
"from mario_gpt.lm import MarioLM\n",
|
22 |
+
"from mario_gpt.utils import view_level, convert_level_to_png"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "markdown",
|
27 |
+
"id": "28c11f07-b604-4603-8fe3-d53874ba02a8",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"### Load Model"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 2,
|
36 |
+
"id": "6f656e57-24a6-4624-b6ed-8aa871581007",
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [
|
39 |
+
{
|
40 |
+
"name": "stdout",
|
41 |
+
"output_type": "stream",
|
42 |
+
"text": [
|
43 |
+
"Using shyamsn97/Mario-GPT2-700-context-length model\n"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"name": "stderr",
|
48 |
+
"output_type": "stream",
|
49 |
+
"text": [
|
50 |
+
"/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:1177: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
|
51 |
+
" warnings.warn(\n"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"name": "stdout",
|
56 |
+
"output_type": "stream",
|
57 |
+
"text": [
|
58 |
+
"Using shyamsn97/Mario-GPT2-700-context-length tokenizer\n"
|
59 |
+
]
|
60 |
+
}
|
61 |
+
],
|
62 |
+
"source": [
|
63 |
+
"mario_lm = MarioLM()"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "code",
|
68 |
+
"execution_count": 3,
|
69 |
+
"id": "1a60f6ed-42be-4d17-af15-151fa24e0f91",
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"TILE_DIR = \"../data/tiles\""
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "markdown",
|
78 |
+
"id": "a7d7bd55-14d4-45a3-9539-c7c385f63070",
|
79 |
+
"metadata": {},
|
80 |
+
"source": [
|
81 |
+
"### Load Dataset (Optional)"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "code",
|
86 |
+
"execution_count": 4,
|
87 |
+
"id": "6c0840d0-ea5b-4111-9198-6b5a716083bd",
|
88 |
+
"metadata": {},
|
89 |
+
"outputs": [
|
90 |
+
{
|
91 |
+
"name": "stdout",
|
92 |
+
"output_type": "stream",
|
93 |
+
"text": [
|
94 |
+
"No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS...\n",
|
95 |
+
"\n",
|
96 |
+
"\n",
|
97 |
+
"\n"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"name": "stderr",
|
102 |
+
"output_type": "stream",
|
103 |
+
"text": [
|
104 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (102116 > 1024). Running this sequence through the model will result in indexing errors\n"
|
105 |
+
]
|
106 |
+
}
|
107 |
+
],
|
108 |
+
"source": [
|
109 |
+
"dataset = MarioDataset(mario_lm.tokenizer)"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "markdown",
|
114 |
+
"id": "c80a131f-c68f-475d-ab24-acd3da814c39",
|
115 |
+
"metadata": {},
|
116 |
+
"source": [
|
117 |
+
"#### View string representation of level"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": 5,
|
123 |
+
"id": "2bdab45e-58cb-4bcb-8d6e-dee6c946d6fd",
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [
|
126 |
+
{
|
127 |
+
"data": {
|
128 |
+
"text/plain": [
|
129 |
+
"['--------------------------------------------------',\n",
|
130 |
+
" '--------------------------------------------------',\n",
|
131 |
+
" '--------------------------------------------------',\n",
|
132 |
+
" '--------------------------------------------------',\n",
|
133 |
+
" '-------------------------------------------------o',\n",
|
134 |
+
" '--------XSSSSS---------------------------------SSS',\n",
|
135 |
+
" '--------X-----------------------------------------',\n",
|
136 |
+
" '--------X-----------------------------------------',\n",
|
137 |
+
" '-------EX--E-X---------------xxxx-?-----------xxxx',\n",
|
138 |
+
" '--------XSS?SX---QQ?QQ------xx<>-x-----------xx--?',\n",
|
139 |
+
" '---------------------------xx-[]--x---------xx----',\n",
|
140 |
+
" '--------------------------xx--[]---x-------xx-----',\n",
|
141 |
+
" 'xxxxxxxxxxxxxxxxxxxxxxxxxxx---[]----xxxxxxxx------',\n",
|
142 |
+
" 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXX']"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
"execution_count": 5,
|
146 |
+
"metadata": {},
|
147 |
+
"output_type": "execute_result"
|
148 |
+
}
|
149 |
+
],
|
150 |
+
"source": [
|
151 |
+
"view_level(dataset.input_ids[:700], mario_lm.tokenizer)"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "markdown",
|
156 |
+
"id": "99be5b3a-c968-4fbd-a51a-f623003072c0",
|
157 |
+
"metadata": {},
|
158 |
+
"source": [
|
159 |
+
"#### Image"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "markdown",
|
164 |
+
"id": "d5614fc2-59bc-40ee-a92a-0cfd971e1ad3",
|
165 |
+
"metadata": {},
|
166 |
+
"source": [
|
167 |
+
"##### Previewing the first 50 columns of the dataset"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"execution_count": 6,
|
173 |
+
"id": "0d6a3bf3-d050-4760-a48e-8b8655142c67",
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [
|
176 |
+
{
|
177 |
+
"name": "stderr",
|
178 |
+
"output_type": "stream",
|
179 |
+
"text": [
|
180 |
+
"/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/Pillow-9.1.1-py3.9-linux-x86_64.egg/PIL/Image.py:992: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
|
181 |
+
" warnings.warn(\n"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"data": {
|
186 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAyAAAADgCAIAAAB0EpUWAAAYPUlEQVR4nO3dT2wcVZ7A8fe6q9OOsWNv4mwDQXHsvRhWyybZbAA7M4qGS9QoWhGFA1oJgmwNITkwBywuHo1G8S0ckBgFZXGCJZgDYsgcMmmNOIyi7MYBokSMUFgIKAECJuGPkzh2/Kf/1B6abZz+U1Wv+3X5VdX3o9EInG93qitdzo/X5Sq551BWeCalsG3vOT09PT09PT19FPuYQi7Unp2enp6enp6ePpq92oAFAAAAVwxYAAAAmjFgAQAAaMaABQAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgWUxKtQfQ09PT09PT09M7s5KWsG1RsEWu4H6zaCkFPT09PT09PT29cy/3HMq6VHc+wPVJ6enp6enp6ekj3qudg6X07PT09PT09PT00ew5yR0AAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANAsJqXaA+jp6enp6enp6Z1ZSUvYtijYIldwv1m0lIKenp6enp6ent65l3sOZV2qOx/g+qT09PT09PT09BHv1c7BUnp2enp6enp6evpo9pzkDgAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAaBaTUu0B9PT09PT09PT0zqykJWxbFGyRK7jfLFpKQU9PT09PT09P79zLPYeyLtWdD3B9Unp6enp6enr6iPdq52ApPTs9PT09PT09fTR7TnIHAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0i0mp9gB6enp6enp6enpnVtISti0KtsgV3G8WLaWgp6enp6enp6d37uWeQ1mX6s4HuD4pPT09PT09PX3I+gdmXlZ4gOo5WEpbQ09PT09PT08fwV4IYSk/AgAAINr+8x8Olf753qHLQoiDBw8uDfgpQgAAAAVLpyshxORYjxBieHh46RcZsAAAAJRljl3OHLtc61cZsAAAALwqLl9ljl1O7+pJ7+opfj5YuYjFgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAePXH6/uEEMUrYN07dLl4BazKu+VYBt6wmp6enp6enp7eqL5ScbSqxUpawrZFwRa5gvvvJKWgp6enp6enp49av9Qfr+9zvdmz3HMo6/Ksd/4GRo2TIeiPPnT2yEj/4OiEe7y5Xwhhn3cvw+SZ9/7de2zgny89PT09vT/90YfOKjyg+dQGLGj3+sNnhRCuM1ZxuhJCjKWr/OpQJrRf5/0JAPBi218SlV9cxr+/OMndCIOjE0dG+mv9amm6AgAAgcAK1jIrrmAV1ZqxhjJ+bY15eH8CALwY31dlBWsZWcu9AVBTWoo08OO8Znz9f6okAACUM+3vRz4iBAAA0IwBCwAAQDPOwVpmnIPljPcnAMALzsFCQ0z4XNnPrwMA4IVpf3/xESEAAIBmDFgAAACacQ7WMiudg+VwNdGxdHQ/LOP9CQDwgnOwUIXztdqLH/EWZyzTrvPR7K9zHSwAgBem/f3IR4TLz8udcGr9cQIAAAPFpFR7AL3m3vN9BqM5Yxn350VPT09Pb2RvGrn3taxti4ItcgVh2261FElL0GvsX3/OrM+MTcP7k56enp4+iH+fqp3kLqX7i6RX6o8+dPbISP/g6IR7vLlfKJ70beDrpaenp6enj0Kvdg6W0rPTe+wHRydqXcO9xPsniY1vDz09PT09PX2DPSe5G8F5xqpvugIAAMuFyzSYwss6FgAACARWsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAM66DZQqHq4mOpcVQxs9tAQAADWEFywjO12ofyoixtG/bAgAAGsWAtfy83AmHGQsAgACJSan2AHrNvef7DBZnLOO2n56enp6enr6ClbSEbYuCLXIF95tFSyno9fZKhjJi7+NmbT89PT09PT19ZW/NZ12iMvR6e/v8xJGR/sHRCdeyuNZl2vbT09PT09PTV1I7B8t1ZKOvox8cnTgy4vJBofdPEhvfHnp6enp6evoGe05yN4LzjFXfdAUAAJYL18EyhZd1LAAAEAisYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZlwHyxQOVxMdS4uhjJ/bAgAAGsIKlhGcr9U+lBFjad+2BQAANIoBa/l5uRMOMxYAAAESk1LtAfSae8/3GSzOWMZtPz09PT09PX0FK2kJ2xYFW+QK7jeLllLQ6+2VDGXE3sfN2n56enp6enr6yt6az7pEZeid/dONl8u+surk70v/PL39d2W/ap+fODLSPzg64frMxbUu07b/wl2/UXr+V//trNLr3XNI7QWY9n6gp49yv+XPCe/x4OjpZ97b2tTtoaf3s1f7KULXkY2+zNLppPivlTPK4Kj7jOX9k0Tn7VHtvWy/6vOb/Hrp6en19o/t6vHYy80DSv9BZebrpacv9Zzk3kRl04nDF4szR63nqW/aaJz37Vdl5usFfBYv5NbMXl01P7X69rXUrStrZq8mc/MB6j26Z+hygwEgzHv/u/ZcB6tZHAYRh3WsJm+UAtXtV2Xa6wX8t/r2tf2ZJ7/p2tgxO5maunCz7b6TD+57vzedjScD0Xv07dhPi1gnjt0xSJUWt0oB4MC0979rzwpWU6w6+ftnT8w8e2Km8peKX9eyDtQ8Qd9+IBAS+YXuyTObLr6xkGj7oO/plsXp3aee77t6Nii9kuJ0NZT56X+iYt4CnJn2/nftGbD0Kw4fhx9rs9/7a9mM8uyJGfu9vx5+rE1o+qytGYK+/UCA5Kz4xfXpV3aMv7N1ePzRw4nc7Uc+fStAfR2+HespLVnx4SCUmPb+d+75iLDpqq4DBUjQtx8wWdZqubJ2oxCiIOPTK7tutPd0zE4GqAf8ZNr737lnBUuzynWd4pJPcUHINV52Qd9+IFgSufnu786tmp+6a3E6Nf1F561LU+3dAeq9e2xXT+meqnw4iPqY9v537lnBaq6yoaRyRjFc0LcfMFw8n98weeqJMwe+7+zd/Nnbt1rXne/dGaBeSfGTwdLJWIAq097/zj0DVhP914HdDv9qvqBvP2C+rLXiq7sHUtc/uf/LzExr6t0tL364fnuA+kbww4NQZdr737lnwNJsevvvln5w9uvf/qksWDqmNH6xA+2Cvv1AsOTjyc/XbTu+aX/H3I8zyc5FtwsimNbX4bFdPZzbjvqY9v537hmw9CuNHS/84l8qf/XXv/3TS//9UeXXHa6uOZb2dTm9vu1XZc7rBZZdXlpTrang9kpOHLs8lla4wjtQxrT3f62ek9yXwfDw8PDw8NKvOF+7fCgjxtJN3iYVlduvKlivF2gSW8YWEu25mNcb9pnWN4JFLKgy7f3v2jNgNcvw8LB8eEfZF+XDO0pfLM0oXu4M4//M4X37VZn5egH/zSXazvU99fWaBwLae1Q2Sy39ccKqAVCVae9/114+82pW6RaGUqrd8jBq/T/Pvlz8h9L8UfrJu8p55eDBg0pjylhanN7Z3D8v1e3/uO03Ss8/vk/hP459eL309PTN67f9JeH9o8B7hi7z9xF9mHoraQnbFgVb5Aruj5RS0Dv3JaXhST68o+rVDQ4ePOjydBWGMmLv42Ztf93P74UPr5eenr55/VBGiIzC6pRp209P30gv9xzKulR3PsCo8TAE/dGHzh4Z6R8cnXCPN/cLIezz7mUjlG7APDh6WgjFoUkIpddr2vtz4LjCCtzg6OkjIwOB7p95b6v33sDjK+j9688193Qo044venqTe9Xv/2o/Rai0NfQe+8HRCdeZo3TeUtUBqNYZS/V93fuSvtw8oPr89vkJpderxJ8/L9X9E+he6S9gM4+voPcvfaz2KO9eUDzVysz9Q0/vZ6/0/ZPLNBjBeeaob9poxD1Dl52vAegaODPt9apS3T9B7+FRvJDrnPshG19hFbKJ/GIulphJdi5YLY30o3+74yEjvyp/krLAT6qvF1iqGceLD7x//2TAMkVx5nDNal0gStfXi0rvnrJbhpWG91JQ3/MLz6/XTN73Tzh6eLT69rX9mSe/6drYMTuZmrpws+2+kw/ue783na1xuUIvfWmichikio3zpDXyKzH6t5//XwjxgtqLq3P7gVqacbz4wPv3TwasgCl99Kb3Y8Gqym4ZNpYWJ45dLlsgrXt7QsDL/glTD1eJ/EL35Jl/nProfzf8xwd9T//rpT/vPvX8zda1H63bpqVvRNl0pYWf24/wMfl4ceXl+yfXwYK7b8d6SiM5V6yppLp/gt7DQc6KX1yffmXH+Dtbh8cfPZzI3X7k07c09nVbuoKlkW/bj1Ay9njxzuH7JwMWAGiTtVqurN0ohCjI+PTKrhvtPR2zkxr7ujVjBUv4uP0IJWOPFy34iDBglp7b1IzzsZZ+nLf0nqxlHzbXvT2Do9WbIPKyf8LUw4tEbr77u3Or5qfyMSs1/UXnrUuX7v2Fxr5uledgaeHb9iOUjD1eXHn5/smAFTBNPe+qckIqrnyWfdjcyPaEjOv+CVkPV/F8fsPkqSfOHPi+s3fzZ2/fal13vnenxr5uTVrB8m37EUrGHi9euH7/ZMCCAn64zJnq/gl6j0pZa8VXdw+krn9y/5eZmdbUu1te/HD9do193Zq0guXb9iOUjD1e6lD5/ZMBC+6WroWikur+CXoPB/l48vN1245v2t8x9+NMsnPR7QfIVfu6NWkFy7ftRygZe7x45/D9kwHLFA5X1xxL/7z86M91sMqcOHa51hXA635+j683EBz2Tyh7uMpLa6o11by+Dk1awSryYfsRYgYeL945fP9kwDKC87XLi6c0FWcOP6+DVabq5Wvr2x7vrzdAVK9+HvQelWwZW0i052Jeb1im2jeiGdOVn9uP8DH5eFFV9fsnl2lYfl7uDOPz5TrLFjwf29VTNu408omSga9Xler+CXoPj+YSbef6nvp6jdeb/Kn2jWjGdbD83H6Ej8nHiwPv3z8tA29YHa3e8333/Jw5KifxWvdaUaX6ek8b9udVpLp/gtsbd7yY3d9Y2fXmwAGNvai4JXPV+9t4uenNCxX/X7k9rhp8varPTx/uXvvxYtr3f7n3taxti4ItcgX330lKkbQEvcb+9edMXO00h2nvz6j9eZm2/6PWH91rCSGklC6pOtu2hRDPjeWMer309Cb3qt//5Z5DWYXa7HGYnp6ePkx96Rv6Sx//9MXiB3xl51GVzqyq9fFf2dlXpVUxvv/T0zevVzsHS3UxjZ6enp7en96BlrOvTHu99PSG9/wUIQBUFy/kOud+yMZXWIVsIr+YiyVmkp0LVouu3jdNujqDKmP3D7QIzfGiCwMWAFS3+va1/Zknv+na2DE7mZq6cLPtvpMP7nu/N52tcXlD1d43zbsClhJj9w+0CM3xoguXaQCA6hL5he7JM5suvrGQaPug7+mWxendp57vu3pWV+8bE6YrYfD+gRahOV50YcACgJpyVvzi+vQrO8bf2To8/ujhRO72I5++pbH3h/YrYNXNzP0DXcJxvOjCgAUANWWtlitrNwohCjI+vbLrRntPx+ykxt4fhqxgCVP3D3QJx/GiCwMWANSUyM13f3du1fzUXYvTqekvOm9dmmrv1tj7w5wVLDP3D3QJx/GiCye5A0BN8Xx+w+SpJ84c+L6zd/Nnb99qXXe+d6fG3h/mrGCZuX+gSziOF11YwQKAmrLWiq/uHkhd/+SXf/+DEOLdLS9+uH67xt4f5qxgmbl/oEs4jhddWMECgJry8eTn67Yd37S/Y+7HmWTnotsPkKv2/jBnBcvM/QNdwnG86MIKFgC4yEtrqjXl/bu/at9s5qxgFZm2f6BX0I8XXRiwAKA6W8YWEu25mNc7vKr2vjFkBcvY/QMtQnO86MKABQDVzSXazvU99fWaB9zTunrfGLKCZez+gRahOV50sQy8ATU9PT29Cf2NlV1vDhxoXu+bWitYhu/PZm8Pvd5+2Y8X03oraQnbFgVb5Aruj5RS0NPT09P70+tS616Epr1eevow9dZ81iUqQ09PT0/vT69LrRUs014vPX2YerVzsFQXt+np6enp/ekdaDkHy7TXS09veM91sABERbyQ65z7IRtfYRWyifxiLpaYSXYuWC26emMZ8lOEqkKz/wMqsseLLgxYAKJi9e1r+zNPftO1sWN2MjV14WbbfScf3Pd+bzpb4wI8qr2xap2DZbjQ7P+AiuzxoguXaQAQFYn8QvfkmU0X31hItH3Q93TL4vTuU8/3XT2rqzdWEKcrEaL9H1CRPV50YcACECE5K35xffqVHePvbB0ef/RwInf7kU/f0tibyZDrYNUhHPs/uKJ5vOjCgAUgQrJWy5W1G4UQBRmfXtl1o72nY3ZSY2+mgK5gibDs/+CK5vGiCwMWgAhJ5Oa7vzu3an7qrsXp1PQXnbcuTbV3a+zNFNwVrHDs/+CK5vGiCye5A4iQeD6/YfLUE2cOfN/Zu/mzt2+1rjvfu1Njb6bgrmCFY/8HVzSPF11YwQIQIVlrxVd3D6Suf/LLv/9BCPHulhc/XL9dY2+m4K5ghWP/B1c0jxddWMECECH5ePLzdduOb9rfMffjTLJz0e0HyFV7MwV3BSsc+z+4onm86MIKFoDIyUtrqjXl/bu/am+a4K5gFQV9/wdd1I4XXRiwAESFLWMLifZcLNGk3lgBXcEKzf4PqMgeL7owYAGIirlE27m+p75e80CTemMFdAUrNPs/oCJ7vOhiSal2C0N6enr6gPY3Vna9OXCgeb3q9vim1gpWyP68mr09UesDd7yY1ltJS9i2KNgiV3B/pJSCnp6ent6fXpda9yI07fXS04ept+azLlEZenp6enp/el1qrWCZ9nrp6cPUq52Dpbq4TU9PT0/vT+9AyzlYpr1eenrD+/LrYMULuc65H7LxFVYhm8gv5mKJmWTngtVS6yno6enpw9qHRkB/ilCVae+fqPUoUz5grb59bX/myW+6NnbMTqamLtxsu+/kg/ve701na1zQgp6enj6sfWjUOgcrZEx7/0StR5nyjwgT+YXuyTObLr6xkGj7oO/plsXp3aee77t6ttbj6enp6cPah0YUpith3vsnaj3KVDkHK2fFL65Pv7Jj/J2tw+OPHk7kbj/y6VsOT0FPT08f1j4cAnodrDqY9v6JWo+lqgxYWavlytqNQoiCjE+v7LrR3tMxO+nwFPT09PRh7cMhIitYwrz3T9R6LFVlwErk5ru/O7dqfuquxenU9Bedty5NtXc7PAU9PT19WPtwiM4Klmnvn6j1WKr8JHchRDyf3zB56okzB77v7N382du3Wted793p8BT09PT0Ye3DITorWKa9f6LWY6kqA1bWWvHV3QOp65/c/2VmpjX17pYXP1y/3eEp6Onp6cPah0NEfopQmPf+iVqPpaoMWPl48vN1245v2t8x9+NMsnPR7Qcy6enp6cPah0NEpith3vsnaj2WqjJgFeWlNdWa8v5E9PT09GHtgy46K1hFpr1/otajqPwkd1vGFhLtuVjC4+Pp6enpw9qHRkSmK9PeP1HrUaZ8wJpLtJ3re+rrNQ94fDw9PT19WPvQiMhPEZr2/olajzLymVezSrcwlFLtlof09PT09PX14/t+Wjx46eOfvlIcksrWokqrU7VGqLIVrBf+/29Mvv/T0zevt5KWsG1RsEWu4P5IKQU9PT09vT+9LrXOwTLt9dLTh6mX9vmJIyP9g6MTLq0QcnO/EIKenp6enp6efs+hrGv580Oav7x09KGzRu2fmBBicHTiyEi/l7qInp6enp6ent47pWmp7t6o/RPz8pjKvUlPT09PT09Pbxpz9s/P18HyMpd5/z3o6enp6enpo9CbxpD9U+VmzwAAAGgEAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmv18HSyHq4eNpcVQpvyL9PT09PT09PSmMWT/xFxrIcRQRoylvT47PT09PT09fUR605izf2KudeVj6Onp6enp6elNY9T+ka7pUqprg/T09PT09PSh7E/vzCrdkllKtVs4q/bj+xLeYx/2j9qABQAAIITY+1rWtkXBFrmC+yQkpUhaoqn9688pDFg++D9omqPgoC0DtAAAAABJRU5ErkJggg==\n",
|
187 |
+
"text/plain": [
|
188 |
+
"<PIL.Image.Image image mode=RGB size=800x224>"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
"execution_count": 6,
|
192 |
+
"metadata": {},
|
193 |
+
"output_type": "execute_result"
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"source": [
|
197 |
+
"img = convert_level_to_png(dataset.input_ids[:700], TILE_DIR, mario_lm.tokenizer)[0]\n",
|
198 |
+
"img"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "markdown",
|
203 |
+
"id": "28a7e683-a9a2-4321-b21a-807daf7aa744",
|
204 |
+
"metadata": {},
|
205 |
+
"source": [
|
206 |
+
"#### Set device"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": 7,
|
212 |
+
"id": "7a6f684a-63a9-4a34-9a57-fd6aa84375a0",
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"device = torch.device('cuda')\n",
|
217 |
+
"mario_lm = mario_lm.to(device)"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "markdown",
|
222 |
+
"id": "3869772f-e3a6-43d4-94ee-40364028bea8",
|
223 |
+
"metadata": {},
|
224 |
+
"source": [
|
225 |
+
"## Generating Levels"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": 45,
|
231 |
+
"id": "1e7589f2-2b48-4174-9fc7-7e7de7ff3615",
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"prompts = [\"many pipes, many enemies, some blocks, high elevation\"]"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "markdown",
|
240 |
+
"id": "aa0a437f-4123-44b2-b08f-985f60165fb2",
|
241 |
+
"metadata": {},
|
242 |
+
"source": [
|
243 |
+
"##### We generate 1399 predictions for an even 1400 output (including the input seed which is just a single block). Mario Levels have height of 14, so we generate 100 columns."
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": 46,
|
249 |
+
"id": "766362fb-8b90-43a4-b405-17fed2342d31",
|
250 |
+
"metadata": {
|
251 |
+
"scrolled": true,
|
252 |
+
"tags": []
|
253 |
+
},
|
254 |
+
"outputs": [
|
255 |
+
{
|
256 |
+
"name": "stderr",
|
257 |
+
"output_type": "stream",
|
258 |
+
"text": [
|
259 |
+
"shape: torch.Size([1, 685]), torch.Size([1, 1400]) first: \n"
|
260 |
+
]
|
261 |
+
}
|
262 |
+
],
|
263 |
+
"source": [
|
264 |
+
"generated_level = mario_lm.sample(\n",
|
265 |
+
" prompts=prompts,\n",
|
266 |
+
" num_steps=1399,\n",
|
267 |
+
" temperature=2.0,\n",
|
268 |
+
" use_tqdm=True\n",
|
269 |
+
")"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "code",
|
274 |
+
"execution_count": 47,
|
275 |
+
"id": "777f94cf-a765-4f7a-a7b4-223c29680e17",
|
276 |
+
"metadata": {},
|
277 |
+
"outputs": [
|
278 |
+
{
|
279 |
+
"data": {
|
280 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAABkAAAADgCAIAAADUoj0kAAAwcElEQVR4nO3dXYxUVaLo8bW7qqjutpvuAJ5SMbT0U8vJmKYvx4+GuSHj5KRtDg9jEOOZHEXo+EHfhJmEHpMJxkzo3AfxwROPGGYa7MTxgVHHBw6t4WFCmMuHIgw3BqPoxQ+0BZUW+4P+qKre92FjWVRX79qraq9da+39/4UY6P7XZlXXrtqL5a5d1qbdaeGZZQnb9p7T09PT09PT09PT09PT09PT09NX2tdI5EJu6/T09PT09PT09PT09PT09PT09JX3cgtYAAAAAAAAQMBYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1mosS+4G9PT09PT09PT09PT09PT09PT0QfbxZFzYtpi1RWZW2HbprdPT09PT09PT09PT09PT09PT0wfZW5t2p0tU19+g5Ebp6enp6enp6enp6enp6enp6el97OWugSW1dXp6enp6enp6enp6enp6enp6+sp7LuIOAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACt1ViW3A3o6enp6enp6enp6enp6enp6emD7OPJuLBtMWuLzKyw7dJbp6enp6enp6enp6enp6enp6enD7K3Nu1Ol6iuv0HJjdLT09PT09PT09PT09PT09PT0/vYy10DS2rr9PT09PT09PT09PT09PT09PT0lfdcxB0AAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFqrsSy5G9DT09PT09PT09PT09PT09PT0wfZx5NxYdti1haZWWHbpbdOT09PT09PT09PT09PT09PT08fZG9t2p0uUV1/g5Ibpaenp6enp6enp6enp6enp6en97GXuwaW1Nbp6enp6enp6enp6enp6enp6ekr7+NytwAAABEQm800T36Xji2Iz6YT2ZlMTWI82Twdr61WDwAAgIhjAQsAABRadPVS79BDXy1pb5oYTo2c/aHh1sN3bH2ntTsdS1alBwAAQMTJvYUQAABEQSI73TJ8fOW5V6YTDe+2PVI7M7rhyLa2iyer1QMAACDiWMACAABFZOKxc8u6X+gafOPOvsF79yQyV+/5aH8VewAAAEQZC1gAAKCIdLz2wo3tQohZKzZat+RK4/KmieEq9gAAAIgyFrAAAEARicxUyzenFk6N3DAzmhr9rHns/EhjSxV7AAAARBkXcQcAAEXEstnbho88cHznt82tHR+/Nla/9HTr+ir2AAAAiDIWsAAAQBHp+IIvblqd+v7D2z8fGq9PHVr11Jlla6vYAwAAIMpYwAIAAEVkY8lPlq45sLK3afLyeLJ5Jpasbg8AAIAoYwELAADMK2vFR+pT+vQAAACIJi7iDgAACtlWzXSiMVOT0KQHAABAxLGABQAACk0mGk61Pfzl4hWa9AAAAIg469GX0rYtcwNL0NPT09PT09PT09PT09PT09PTB9bHk3Fh22LWFpnZ0re0LEFPT09PT09PT09PT09PT09PTx9kb23anS5RXX+Dkhulp6enp6enp6enp6enp6enp6f3sZe7BpbU1unp6enp6enp6enp6enp6enp6Svv43K3AAAYIjabaZ78Lh1bEJ9NJ7IzmZrEeLJ5Ol5rSg93UXu8dBsPAAAAAsYCFgCE06Krl3qHHvpqSXvTxHBq5OwPDbcevmPrO63d6VjSiB7uovZ46TYeAAAABEzuLYQAAFMkstMtw8dXnntlOtHwbtsjtTOjG45sa7t40pQe7qL2eOk2HgAAAASMBSwACK1MPHZuWfcLXYNv3Nk3eO+eRObqPR/tN6iHu6g9XrqNBwAAAEFiAQsAQisdr71wY7sQYtaKjdYtudK4vGli2KAe7qL2eOk2HgAAAASJBSwACK1EZqrlm1MLp0ZumBlNjX7WPHZ+pLHFoB7uovZ46TYeAAAABImLuANAaMWy2duGjzxwfOe3za0dH782Vr/0dOt6g3q4i9rjpdt4AAAAECQWsAAgtNLxBV/ctDr1/Ye3fz40Xp86tOqpM8vWGtTDXdQeL93GAwAAgCCxgAUAoZWNJT9ZuubAyt6mycvjyeaZWNKsHu6i9njpNh4AAAAEiQUsAAi5rBUfqU+Z28Nd1B4v3cYDAACAYHARdwAIJ9uqmU40ZmoShvZwF7XHS7fxAAAAIGAsYAFAOE0mGk61Pfzl4hWG9nAXtcdLt/EAAAAgYNajL6VtW+YGlqCnp6enp6enp6enp6enp6enpw+sjyfjwrbFrC0ys6VvaVmCnp6enp6enp6enp6enp6enp4+yN7atDtdorr+BiU3Sq9Vv2L8+YIvLjz8h9zvR9c+U/Ddszf8JlLb1+3xoqenp/er33fXyb07Orf0Hysdd3QKIezTpct8j574F6nx6Pbzoaenp6enp5/blzF/kOpZf6CvpJf7FEKprdNr2Oev/jh/nLsGFOXt09PT04ep39Jfek7pzCaFEHt3dM79bs+QGOgucqueIbkJqJ4/H3p6enp6evq5vdT8oYxedjz09Lmei7hHSMHqj8sXo7l9GCc2m1k8cXHh1Miiq5dSYxcWT1xMZqYM6mXpNh7VTB+/LEWPrzOnnG8j5c0mjRC1/QcAAB/Jzh8iO99AwOTOwIK5XBZ6KjyPKRzbh4kWXb3UO/TQV0vamyaGUyNnf2i49fAdW99p7U7Hkkb0ut1f3Zg+flnqHl/3OWVOz5Dc1zUXtf0HAAB/eZw/lN0DZeAMrEhYePgPjx8cf/zg+NxvOV+v8Dwm07cPQyWy0y3Dx1eee2U60fBu2yO1M6Mbjmxru3jSlF6WbuNRzfTxy6r64zvQfe1X/u/zfxV8XXNR238AAABCjwWs8HMWd/asa7BPvF2wBvT4wXH7xNt71jWICt6LZ/r2YbRMPHZuWfcLXYNv3Nk3eO+eRObqPR/tN6iXpdt4VDN9/LKi9viqxs8HAAAgTHgLYeQUPY+J7cNQ6XjthRvbhRCzVmy0bsmVxuVNE8MG9bJ0G49qpo9fVtQeX9X4+QAAAIQJC1ghN/e8JPvE23/cuUEI8djTr1t3dxXG634Tqe3DdInMVMs3pxZOjWRr4qnRz5rHzp+/5ecG9bJ0G49qpo9fVnUf3/xrXXm5Htamfyv7rwpI1PYfAACAcGMBK1rsE2+7/JHtwzixbPa24SMPHN/5bXNrx8evjdUvPd263qBelm7jUc308cuq7uNb9LJWPUPzfl1/Udt/AAAAwo0FrAhxTlya749sHyZKxxd8cdPq1Pcf3v750Hh96tCqp84sW2tQL0u38ahm+vhlRe3xVY2fDwAAQJiwgBVyo2ufyX8X3mNPv14Q5C8Dja59Jmrbh+myseQnS9ccWNnbNHl5PNk8E0ua1cvSbTyqmT5+WVF7fFXj5wMAABAmLGCFX25ZZ/vPfzb3u489/fpzf38/yttHCGSt+Eh9ytxelm7jUc308cvy/fG1Ojrn+9ZA909vBvRy3SsTRW3/AQDAFx7nD2X3QBlYwILo6+sTQuzatYvtwyy2VTOdaMzUJAztZek2HtVMH78sRY+vy2xS/HiJK2dOmbvWlct1r/K//n88DrRKorb/AADgI+/zh/J6oDw11R4AAtLX11fwmX1CCOvurtwXnWWgyG4fJppMNJxqe/jLxSsM7WXpNh7VTB+/LBWPr/ts0jHfcpXporb/AADgF9n5Q5TnGwhY3LKEbUvcgN6svvDmd3flPrlv7npQBLdPb3R/pW7Jn1fvNLcvSfPxqO5NH79s7//j62E26ShvThmyn6fq8dDT09PT05vRS84fZPujut1ferP6J/6Utm0xa4vMbOlbWpZIxgW9Qf2K8edzf8ydo+SsARUsADlvwft/zb+J1PZ1e7zo6enp/epfflLtu+eYP9DT09PT04evZ/5Ar3NvbdqdLlFdf4OSG6VX2qt+QTEd+zN9mHrVz3eeL2b1+Qv6jvwPaZ37Mazb7129d0fnlv5jpTfe0SnYH+jp6enp6ektse+uk1LzB/v0MeYb9IH1chdxl9o6vaL+uQ/kbuXd9hVh2L53ej6+9PT5Pc8X+vn6/NUr549z17C29JeeU3o/8999PPT09PT09PTh6GXnD8w36APr+RRC8/T/7bo/7vhFiSBq24e5YrOZ5snv0rEF8dl0IjuTqUmMJ5un47V+9brxMn6eL97ptj8o3Z8LVq9yX5RdwypvNhlKuu0/AABUkez8gfkGgsEClnly/4J1+Yer07j/y3bHL0T/3376rxBieyi2D3Mtunqpd+ihr5a0N00Mp0bO/tBw6+E7tr7T2p2OJX3pdeNl/DxfvNNtf1C3Pxddvcp9a741rArvTujptv8AAFBdsvMH5hsIQE21B4CqKfjXLNtH1SWy0y3Dx1eee2U60fBu2yO1M6Mbjmxru3jSr143QY4/Cs8X3fYHRfvzwsN/ePzg+OMHx+duwfm6y/IWXOi2/wAAAKAAC1jRlX9GBtuHJjLx2Lll3S90Db5xZ9/gvXsSmav3fLTfx143gY0/Is8X3fYH3/dnZ3Fqz7oG+8TbBWtYjx8ct0+8vWddg3A9RQsudNt/AAAAkI+3EEaX6WdIReGMkghKx2sv3NguhJi1YqN1S640Lm+aGPax101g44/I80W3/UH1/lz0PCyUTbf9BwAAAPk4Ayu6TD9DKiJnlERNIjPV8s2phVMjN8yMpkY/ax47P9LY4mOvm8DGH5Hni277g7/789zzqpxTrpwTsgq+xUlYZdBt/wEAAEA+zsCKLtPPkIrIGSVRE8tmbxs+8sDxnd82t3Z8/NpY/dLTret97HUT2Pgj8nzRbX9Quj8XLFrNXcOCLN32HwAAAORjASu65n4qGdtH1aXjC764aXXq+w9v/3xovD51aNVTZ5at9bHXTWDjj8jzRbf9Qd3+/MedG1z+iPLotv8AAAAgHwtY0WX6GVJR+Nd4BGVjyU+Wrjmwsrdp8vJ4snmm1AfYy/a6CWz8EXm+6LY/+Ls/j659Jv+NgY89/XrBzfOXsUbXPlPx8CNHt/0HAAAA+VjAii7Tz5CKyBkl0ZS14iP1KXW9bgIYf6SeL7rtDz7uz7llqe0//9nc7z729OvP/f39uV+3Ojrn+7sGukXPkPehRYJu+w8AAFUhO39gvoEAcBH36DL9DKno/Gs8OmyrZjrRmKlJKOp1E+T4o/B80W1/qMr+3NfX19fXl/8Vl9mkEKJnSAx0V/IXhodu+w8AAFUkO39gvoFgsIAVXaZ/SmBEPlUtUiYTDafaHv5y8QpFvW6CHH8Uni+67Q/q9ue+vj7r7q6CL1p3d+W+mFvDcp9NOphTOnTbfwAAqBbZ+QPzDQTGevSltG3L3MAS9FXsB7fyP4fdsD/Th6lX/Xzn+WJW/88Tzzu/ya1P5T55cO561q5duwpOxXI30C2Ormd/oKenp6enj3ovNf+UfW8g8w36CvuaZFwk4yIRE5blaev01e1t27alHmHP7Dzmbl+3x4uevpKe5wt9fp+za9eua1+cs25VEHjXM6Td/aWnp6enp6cPvpcie2Ur5hv0FfbxqXTpLh99dXvrxwf2uQ+ufcV5Q1DBdW1yV7qZ7+1CBVfD2b7ipy0bvf1Nu+V+oLo9vvT0+VQ/33m+VLd/6X+c3Lujc0v/sZKlc2a+ffpYrt9++rpb2df/cW7vZfu6/Xzo6enp6enpg+9l5w/MN+iD7OWugSV7KgB9dXsXqq+GY8T2dXu86Okr6V3wfNG239J/bO+OEpeNyL+uhOreOz1/nvT09PT09PSV98w36LXtuYh7RKn+PDLTtw9fxGYziycuLpwaWXT1UmrswuKJi8nMlI89POL5UpTq/dNj7z7nmzvbU93DI16vAABeaDLfkMV8A3qKV3sAqI7cGRmK/k1r+vbhi0VXL/UOPfTVkvamieHUyNkfGm49fMfWd1q707GkLz084vlSlOr903vv5f9bBtnDC16vAABe6DPfkMV8AxriDKyIMv0MKf41boREdrpl+PjKc69MJxrebXukdmZ0w5FtbRdP+tXDI54vRaneP9mfw43HFwDgBfMNwEcsYEWUEdeoquL24ZdMPHZuWfcLXYNv3Nk3eO+eRObqPR/t97GHFzxf5qN6/2R/DjceXwCAF8w3AL+wgBVRpp8hxRklpkjHay/c2C6EmLVio3VLrjQub5oY9rGHFzxf5qN6/2R/DjceXwCAF8w3AL+wgBVRpp8hxRklpkhkplq+ObVwauSGmdHU6GfNY+dHGlt87OEFz5f5qN4/2Z/DjccXAOAF8w3AL1zEPaJMP0OKM0pMEctmbxs+8sDxnd82t3Z8/NpY/dLTret97OEFz5f5qN4/2Z/DjccXAOAF8w3AL5yBFVGmnyHFGSWmSMcXfHHT6tT3H/7P//tfQohDq546s2ytjz284PkyH9X7J/tzuPH4AgC8YL4B+IUzsCLK9DOkOKPEFNlY8pOlaw6s7G2avDyebJ4p9YG+sj284PkyH9X7J/tzuPH4AgC8YL4B+IUFrIjKnZGh6N+0pm8f/spa8ZH6lLoe7ni+uFO9f5bsrY7O+b410C16hoLuIYXXKwCAF1Wfb8hivgEN8RbCiDL9DCn+NW4E26qZTjRmahKKenjE86Uo1funx95ltieE6BkSA92B9vCI1ysAgBeazDdkMd+AnljAiijTr1HFNX2MMJloONX28JeLVyjq4RHPl6JU759eevfZniN/zqe6h3e8XgEAvNBhviGL+Qa0FbcsYdsSN6Cvbu8X08+Qmm/7uj1eEe+v1C358+qd6nrZ8RjX+4XnS9Fe9f5Zuvcw23M4cz7V/VG9Hy/del6v6Onp6em99NWfb8iOn/kGvcZ9TTIuknGRiAnL8rR1+ur2fjH9DKn5tq/b40VPX0nvF54vevZSZK8cUUav28+Hnp6enp6ePvheCvMN+oB7a9PudOkw7wZaLb+V0a8+IPH24C39R/fuWC3VP3riTqnx6DZ+03vVP396+iD7l5+89nx/7oNrX3QWoQrOpcqdXTXfknHBGVjbfzzHPGqv/6r7fXed3Lujc0v/sdJxR6cQwj59TLe+ZJnv0RP/4j3W8PGip6enp6fXoZedP6iev6mezzD/pK+kl/sUQqmta9uvu3+5x97qWD3QLddLPSE1HL/pfQA/f3p6PXsXvnwKoW73V89+S3/pOVz+mfZa9Xt3FHkLwHxXrOgZkpuA6vl40dPT09PT69DLHt91G49u46cPcR/Ri7jf3POpVCDbq6Z6/Kb3UCQ2m1k8cXHh1Miiq5dSYxcWT1xMZqZ87KEIn0LoC4/7szOHm28jc2dvuvVQhNdDAAgHRfPhso/Xio4vzDegJ7kzsELj64FrJ/Uc/Ot1Cx+5k31yQXm9aqrHb3oPRRZdvdQ79NBXS9qbJoZTI2d/aLj18B1b32ntTseSvvRQxJczsOB9f3afw82lST/fZSxkL2+Bong9BIBwUDcflj2+lzce71TPT4AyRPQMLIezGtIzdO2XmLM+UmGvmurxm97Dd4nsdMvw8ZXnXplONLzb9kjtzOiGI9vaLp70q4cirF75IvT780D3tV/5v8//VfB1SAn9/gMAEaHbfJjjCyIl0gtYOV8PLM+dwuPlzWiyvWqqx296Dx9l4rFzy7pf6Bp8486+wXv3JDJX7/lov489VFD9qaDRwf6MSrD/AEA46DYf5viC6GABC4CEdLz2wo3tQohZKzZat+RK4/KmiWEfe6jAGVh+YX9GJdh/ACAcdJsPc3xBdET0GliOdfcvz52/4+XNaLK9aqrHb3oPFRKZqZZvTi2cGsnWxFOjnzWPnT9/y8997KEC18DyS7j35/xrXXm5Htamf1M7nvAJ9/4DANGh23yY4wuiI9ILWOLHi3/nLq7ke6+a6vGb3sN3sWz2tuEjDxzf+W1za8fHr43VLz3dut7HHiqweuWXcO/PRS9r1TM079chK9z7DwBEh27zYY4viI6oL2AVkP0wO90+/E71+E3vUbl0fMEXN61Off/h7Z8PjdenDq166syytT72UIEzsPzC/oxKsP8AQDjoNh/m+ILoYAFLiOvfm6aiV031+E3v4aNsLPnJ0jUHVvY2TV4eTzbPlPqAXtkeKrB65Rf2Z1SC/QcAwkG3+TDHF0QHC1hCCHHwr58OdIt193s9nUe2V031+E3v4busFR+pT6nr4S/OwPJXyf3Z6uic71sD3UXefKdJ7+W6V6gcr4cAEA6+z4dlj+8Vjqck1fMToAx8CuFPZE/q0e0kINXjN71H5WyrZjrRmKlJKOqhCKtXvvC4P7vM3kSxS0rp0w90X/uV//v8XwVfhxReDwEgHBTNh2WP72WPxyPV8xOgPBFdwCpY+1h3//KCJeGCQLZXTfX4Te+hyGSi4VTbw18uXqGohyK5M7BQCS/7s/vszZE/h9Othzq8HgJAOKiYD1dyvFZxfGG+AW3FLUvYtsQNTO8dcy/+XfAV9z+6BKaP3/Ret/0tZP2VuiV/Xr1TXS87ntD3fpnvDCzd7q/mfen92cPszeHM4XTrZWn+eOnW83pIT09PH47e//mw5PH6qOrji+L5xlG9H1963fsn/pS2bTFri8xs6VtalkjGhdH9y0+qPXtf9c9T9fhNF7X9mT7cfe75/twH177onEVVsBSVW5xyOccqfw1r+4//i47ni7991F6f2X/o6enp6ekr72XnD6b/e5P5A30lvbVpd7pEdf0NSm6Unp6ent6Xft9dJ/fu6NzSf6x03NEphLBPH5Pqef1371X//HXr2R/o6dX1Azfvd37/qzO/dinfbH/V+c2W4QeVjoeenl5db/r8zfTx04e7l7sGltTW6enp6ekr7Lf0H9u7o8SJ2flnbsv2suOJWq/6569b752ejxc9vc593cbNdRs3v/2/k2+2vzrfr1wWwHjo6enV9aYfr00fP32I+4hexB2AECI2m1k8cXHh1Miiq5dSYxcWT1xMZqZ87OEL9znB3NmAbA93qn/+uvXQBK+34TP5l32Tf9knflyiKvorPwMQDEXzYdOP16aPH2EVr/YAAFTNoquXeoce+mpJe9PEcGrk7A8Ntx6+Y+s7rd3pWNKXHn7x8v+1KunhTvXPX7ceOuD1NsRYnwK0om4+bPrx2vTxI5Q4AwuIrkR2umX4+Mpzr0wnGt5te6R2ZnTDkW1tF0/61QMAysPrbfh4fG9gGW8hBFAJ5sOAQTgDC4i0TDx2bln3C12DNXb2H8u7fvvmL+/5aP/7S9f41QMAysPrbVi5rE9xchZQFcyHAVNwBhYQael47YUb24UQs1ZstG7JlcblTRPDPvYAgPLwegsAwWA+DJiCBSwg0hKZqZZvTi2cGrlhZjQ1+lnz2PmRxhYfewBAeXi9BYBgMB8GTMFbCIFIi2Wztw0feeD4zm+bWzs+fm2sfunp1vU+9gCA8vB6G2757xbkoldAdTEfBkzBAhYQaen4gi9uWp36/sPbPx8ar08dWvXUmWVrfewBAOXh9TbEJv+yL3/R6q1Vdfe9N1nF8QARx3wYMAULWECkZWPJT5auObCyt2ny8niyeabUB7TL9gCA8vB6G1a51SvnJKy6jZvve2+SNSygipgPA6ZgAQuAyFrxkfqUuh4Vsjo65/vWQLfoGaq0hzvVP3/demiF19uwyi1jsXQFaML3+bDpx2vTx49Q4iLuQHTZVs10ojFTk1DUwxcuswEhRM+QGOiuqIc71T9/3XpogtdbAAiGovmw6cdr08ePsGIBC4iuyUTDqbaHv1y8QlGPyrnPBhz5cwLZHu5U//x166EPXm9Dr27j5rdW1eWffpV/WXcAgVExHzb9eG36+BFi1qMvpW1b5gaWoKenp6cPoB/cKnHyhey52QPd4uh6Xv/dqP7569azP9DTq+v33rLf+U3dxs0FV3B3OCtZuTWsnq8f1Gr89PT03pk+fzN9/PQh75/4U9q2xawtMrOlb2lZIhkX9PT09PQB9C8/qfbdQ7z+u/eqf/66YX+gp1fXD9y831m06vr99JvtrxZdwPrPf72S+9bmFye0Gj89Pb0+8wfVx2vTx08f7t7atDtdorr+BiU3Sk9PT09Pr2G/766Te3d0buk/Vjru6BRC2KePRapnPkBPr66/fOna5yYVXb1yOGtYzu8X/VNG6XhWH5D4B+qW/qOPnrhTavsDN1874+xXZ37tUr7Z/uq1v2L4Qant6/b4qu5XjD9f8MWFh/+Q+/3o2mcKvnv2ht8oHU/U7q9u8wfVx2vV95f5Bn0lvdynEEptnZ6enp6eXqt+S3/pOVb+dRyi1nun5+NLTx+OXlZ541l3/3KPvdWxWuofnM72f1yn2+x+bS8n+/fnJRawdHu8gu/zV3OcP85d0wlyPKp7He5v1I7Xpo+fPsR94UXcY7OZxRMXF06NLLp6KTV2YfHExWRmSm6rAHwi+3xU3QMh4Myx5vvu3NlV1HoYitdzI7icfiWEuO+9yW2HmgMcjri559MKAxeTf9nnLF3Vbdw836/8DB4VrOa4fDEcyri/iubDoTlea3J/gfIUnoG16Oql3qGHvlrS3jQxnBo5+0PDrYfv2PpOa3c6lqzK+IAok30+qu6BcHCfY9HDRLyeowxfD1w7CevgX69bqMqdnJULKsH6lI9cFm4qPC9JT+XdX3Xz4XAcr/W5v0AZCs/ASmSnW4aPrzz3ynSi4d22R2pnRjcc2dZ28WRVBgdEnOzzUXUPANATr+f6y51+9daquqK/RDVOwhI/rl71DF37JeasZ5Und46VLxmEEAsP/+Hxg+OPHxyf+y3n6yE7D6vs+8t82F3U7i9Cpsg1sDLx2Lll3S90DdbY2X8s7/rtm7+856P97y9dE/zgAMg+H1X3AAA98XpuBGehSgix9ncv5r54+Nle51v3vTdZnWEJIX4838pZvbq551NfTr8SP10MqwhOzvLOWazZs67hsadft+7u2rOuIfetxw+O2yfe/uPODSJE52FVeH+ZD7uL2v1FmBRZwErHay/c2C6EmLVio3VLrjQub5oYDnpcAIQQ8s9H1T0AQE+8nuus4OpX+atXzh+dNSzAo6LnJYWY7P1lPuwuavcXYVL4FkIhRCIz1fLNqYVTIzfMjKZGP2seOz/S2BL8yAAI+eej6h4AoCdez82SW7EqWLoK/l2E6+5f7rxzUPj05kH4bu575ewTb+9Z17BnXYN94u2SsXEqv7/Mh91F7f4iTIqcgRXLZm8bPvLA8Z3fNrd2fPzaWP3S063rgx8ZACH/fFTdAwD0xOu5Keaeb3X42d7qnoSV/+bB3GKWv/LfLchFrypUsIgzd00nZMq4v8yH3UXt/iJMir6FcMEXN61Off/h7Z8PjdenDq166syytYEPDIAQ8s9H1T0AQE+8nsMvfl39KmfyL/vyF62qfrUvoznXfprvj+FT3v1lPuwuavcXYVJkASsbS36ydM2Blb1Nk5fHk80zfAAzUD2yz0fVPQBAT7yeo0Lr7l9+c4//byHMrV45J2HVbdx833uTrGFJGV37TP4b5R57+vWCIH9ZJwQXca/8/jIfdhe1+4swKbKA5cha8ZH6VJBDATAf2eej6h4wmtXROd+3BrqLvH0maj2Mxuu5EfIv4l5wQfcqOvjXTwe6xbr7fT79ypFbxmLpqjy5ZZrtP//Z3O8+9vTrz/39/WBHpJYv99f3+XDIjtdVv79AGQov4m5bNdOJxkxNoiqjAZBP9vmougdCwGV2JYToGRID3ZHuYShezw1y+Nle53JXc3+jAxUnYSEAfX19fX191R5FcFzur6L5cGiO15rcX6A8hQtYk4mGU20Pf7l4RVVGAyCf7PNRdQ+Yzn125cifY0Wth7l4PTfR2t+9WN3TrwrWqvI/jrBoUIm6jZvfWlWXf/pV/mXd4VFfX591d1fBF627u3JfDNkaVnn3V8V8OEzHax3uL1A269GX0rYtcwNL0NPT09PTG9cPbpU4OUX2XPcQ9EfXMx+gp1fV771lf+4q5m+tqiva5FZ23lpV9+f/+E7peNb8d8L7WwVv7vlU9t8Le2/Z7/ymbuPmgiu4O5yVrNwaVs/XD2r1eOnW//PE885vcus1uU/im7u+s2vXrg8afqPV+E2/v7rNH1Qfr1XfX+Yb9JX08WRc2LaYtUVmtvQtLUvQ09PT09Ob2EuRvVJDCPonfqXX40VPH6Y+n5crQKkeT8+QEEOflhxG2eMRQjiLVl2/nxbi1aKZ8y0n2/zig1o9Xrr1Obt27XLWdKy7u3JrOvl27dol1O8/Ubu/UkJwvJYdj2zPfIO+kt7atDtdorr+BiU3Sk9PH1j/8pNqL3fC6wO9zv2+u07u3dG5pf9Y6bijUwhhnz5G79LzfKenV9cP3HztjKRfnfm1S/lm+7W1ni3DDyodj+r+8qVrnxP1Zvurc0+/cry1qu4///WK8/tF/5RROh56+vze9PkDx2v6KPfzfgphUVJbp6enD6B/7gO5W3m3XfJSKnr+fOjD3W/pLz3ny78uA717752e+wM9vc79j+s4m92v/eRk//68xAKWnvdXHd3GT29ir9vxl+M1Pb3HXu4MLMRmM82T36VjC+Kz6UR2JlOTGE82T8drq9XDLL7vD4NbE83/dd1NdvyicCP9fytztFf+l9z/4QEC9vLdJ53fuMz58md79ulj9C49z/dIYb4RsPxrQrlkubUtqTOwNDTyTVy4nn7lyJ2EJXUGFqLG9/mz6fOHgI/XHC/MEvr1CrkzsLDo6qXeoYe+WtLeNDGcGjn7Q8Oth+/Y+k5rdzqWrEoPs6jYH3IrVi4LVU7jvpK14xei/28//VcIsV3uzgFV4/x/S3q/ekQB840q4tP3ACnq/j2l2/FXz+M1xwuzhH69oiaYvyY0EtnpluHjK8+9Mp1oeLftkdqZ0Q1HtrVdPFmtHmbReX8oWL0CAIQY843g1W3c7H46klSmv9zpV2+tqiv6Swhx33uT2w41V3mg0J7O8+co4OdpltCvV3AGlrRMPHZuWfcLXYM1dvYfy7t+++Yv7/lo//tL11Srh1m03R/mnoEFAAgx5hvV4rI+Fb6Ts5yFKiHE2t+9mPvi4Wd7nW95+UBGQGg8f44Ifp5mCfd6BWdgSUvHay/c2C6EmLVio3VLrjQub5oYrmIPs2i7P7B6BQCRwnwD6hRc/Sp/9WruH4GStJ0/RwQ/T7OEe72CBSxpicxUyzenFk6N3DAzmhr9rHns/EhjSxV7mEXb/SH/DCwAQOgx30CQnLOu8n/j4F2E8ELb+XNE8PM0S7jXK3gLobRYNnvb8JEHju/8trm14+PXxuqXnm5dX8UeZtF2f+AMLACIFOYb1ZX/bsFwXPRqPmt/92LBotXhZ3vnfhFwoe38OSL4eZol3OsVLGBJS8cXfHHT6tT3H97++dB4ferQqqfOLFtbxR5m0XZ/4BpYABApzDeqaPIv+/IXrbgaFOBO2/lzRPDzNEu41ytYwJKWjSU/WbrmwMrepsnL48nmmVIfGKm6h1m03R9YvQKASGG+US251SvnJKy6jZvve2+SNSzAhbbz54jg52mWcK9XsIBVpqwVH6lP6dPDLBruD5yBBXNZHZ3zfWugW/QM0cv1iBTmG9WSW8aKwtJV/lXbuYI7yub7/Fm346/mx2uOF2YJ63oFF3GXY1s104nGTE1Ckx5m0Xl/YPUKhnKZ7QkheobEQDe9RI+IYL6BwBx+tte53NXc3wBeKJo/63b81fZ4zfHCLKFfr2ABS85kouFU28NfLl6hSQ+z6Lw/8CmEMJH7bM+RP+ejd+8RHcw3qq5u4+a3VtXln36Vf1n3UFr7uxc5/QplUDF/1u34q/PxmuOFWUK/XmE9+lLatmVuYAl6enpN+sGtahe/eX2g17mX2v9lz72PYH90Pc93enpV/d5b9ju/qdu4ueAK7g5nJSu3htXz9YNajb+M+5u7j2+tqiva5Fbu3lpV9+f/+E6r8dOHuzd9/sDxmj7KfTwZF7YtZm2RmS19S8sS9PT0+vS2bQshLMsqkcpztvzkQEar+0tPn99Lkb1yRAT7J36l1+NLTx+mXgjhLOh0/X5aiFeLZs63nGzziw9qNf5KXp+9XOFLt/HTh7uXwvGanl6r3tq0O12iuv4GJTdKT08fWP/yk9f+D9JzH1z7ovMGwILrWOWubDXf2wMLrn61/cdzQnl9oA+y33fXyb07Orf0Hysdd3QKIezTx+h97Hm+m9XLPl94fOnp6cPaR23+wOs5fSW96fMHuWtgSW2dnp5eq96FL1e/0u3+0pvYb+k/tndHictA5F8ngt7f3js995+o9Ty+9PT09NGcP3in5+NFX93e6P0tXvCl2GymefK7dGxBfDadyM5kahLjyebpeO18m6B371XT7f7Sm7X/5PD5g9CHc0yd7/8LzT2a0vvbwyyBPb66HR/p6enpCzLdjqeRPV5rsj/Q+/t8KZvv4y9cwFp09VLv0ENfLWlvmhhOjZz9oeHWw3dsfae1Ox1LFv0L6N171XS7v/Rm7T85uTOwWMOCDrz8fyF6dT3MEszjq9vxkZ6enn5ur9vxNJrHa332B3p/ny/l8X38hW8hTGSnW4aPrzz3ynSi4d22R2pnRjcc2dZ28eR8A6J371XT7f7Sm7X/5LB6BQBwodvxkZ6enn6+HtWl2/5AH7L1iiLXwMrEY+eWdb/QNfjGnX2D9+5JZK7e89F+lzHRu/eq6XZ/6c3afxy+XAMLABBiuh0f6enp6aEn3fYH+jCtVxRZwErHay/c2C6EmLVio3VLrjQub5oYdvkL6N171XS7v/Rm7T8OzsACALjT7fhIT09PDz3ptj/Qh2m9osgCViIz1fLNqYVTIzfMjKZGP2seOz/S2OLyF9C796rpdn/pzdp/HJyBBQBwp9vxkZ6enh560m1/oA/TekXhRdyFELFs9rbhIw8c3/ltc2vHx6+N1S893bre5S+gd+9V0+3+0pu1/zg4AwsA4E634yM9PT099KTb/kAfpvWKIgtY6fiCL25anfr+w9s/HxqvTx1a9dSZZWtd/gJ691413e4vvVn7j4NPIQQAuNPt+EhPT08PPem2P9CHab2iyAJWNpb8ZOmaAyt7myYvjyebZ+b5gEN6j71qut1fen/7YLB6BQBwp9vxkZ6enh560m1/oHfvVfN3/EUWsK7dzIqP1KckhkVfVbrdX3p/e9U4AwtasTo65/vWQLfoGaJX28MsAT++uh0f6enp6XN0O55G/Hhd9f2B3t/nS4X8Gn/hRdxtq2Y60ZipSXjcLn116XZ/6f3tA8PqFfThcjQVQvQMiYFueoU9zBLY46vb8ZGenp6+gG7H08gerzXZH+jdM3PnD4ULWJOJhlNtD3+5eIXHv4C+unS7v/T+9oHhUwihCfejqSP/mErvbw+zBPn46nZ8pKenp8+n2/E0ysdrHfYHen+fL5XwffzWoy+lbVtiBJYl6OnpNekHt15bnH7ug2tfcRahCs6lyp1dNd8SVcEZWNt/fMXg9YE+yD63P3she24zfcn+6Hqe7yb1ss8XHl96evqw9lGbP/B6Tl9Jb/r8oSYZF8m4SMSEZXnaOj09vT69X+Y7A0u3+0sf7l6K7Dvz6Uv2uu0P9O69FB5fenr6EPdSNDz+yva6/fzpzeqlaLi/WfbpY3t3dG7pP1a67ugUQtC795t2p0uWP91Efnly310ntbq/9PQ696qfj/T+9ry+0Ue5l3294vlCT09PT09PH7W+Rgixpf/Y3h0l3gaZ/z5JevfeO6l/3eV63e4vPb3OvXflPR/p/e1123/o6XV+vdJt/PT09PT09PT0SvsaL7eZO7uid+9V0+3+0tPr3MMsuu0/9PRB9rJ0Gz89PT09PT09vbo+7vE2c9FXl273l55e5x5m0W3/oacPspel2/jp6enp6enp6RX1Nd63CAAAAAAAAASPBSwAAAAAAABojQUsAAAAAAAAaI0FLAAAAAAAAGiNBSwAAAAAAABojQUsAAAAAAAAaI0FLAAAAAAAAGgtnvud1dE5XzTQLXqGCr9I796rptv9pafXuYdZdNt/6OmD7GXpNn56enp6enp6ekV9TclaCNEzJAa6vW6dPgC63V96ep17mEW3/YeePshelm7jp6enp6enp6dX19eUrOfeht69V023+0tPr3MPs+i2/9DT6/x6pdv46enp6enp6emV9lbJNN+A5LnuEeyPrk/btsRNLEtI9YNbE1Lj0e3nQ08fZK/6+Ujvb8/rG32Ue9nXK54v9PT09PT09FHr5RawUNITf0rbtpi1RWa29L/cLEsk40Kqf/lJiQkrEHGqn4/0/va8viHKZF+veL4AAICo+f/elNjhpnnY9wAAAABJRU5ErkJggg==\n",
|
281 |
+
"text/plain": [
|
282 |
+
"<PIL.Image.Image image mode=RGB size=1600x224>"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
"execution_count": 47,
|
286 |
+
"metadata": {},
|
287 |
+
"output_type": "execute_result"
|
288 |
+
}
|
289 |
+
],
|
290 |
+
"source": [
|
291 |
+
"img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]\n",
|
292 |
+
"img"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"cell_type": "markdown",
|
297 |
+
"id": "7233c86a-eb02-48cb-8369-bb8a521bc330",
|
298 |
+
"metadata": {
|
299 |
+
"tags": []
|
300 |
+
},
|
301 |
+
"source": [
|
302 |
+
"#### Check if the model generated the correct level\n",
|
303 |
+
"##### Because of the stochastic nature of the model and the small training dataset, the model may generate levels that do not completely match the given prompt"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": 50,
|
309 |
+
"id": "d3489875-e648-4c75-97f0-7ae55dc51b81",
|
310 |
+
"metadata": {},
|
311 |
+
"outputs": [
|
312 |
+
{
|
313 |
+
"data": {
|
314 |
+
"text/plain": [
|
315 |
+
"'some pipes, many enemies, some blocks, high elevation'"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
"execution_count": 50,
|
319 |
+
"metadata": {},
|
320 |
+
"output_type": "execute_result"
|
321 |
+
}
|
322 |
+
],
|
323 |
+
"source": [
|
324 |
+
"mario_lm.prompter(generated_level)[0]"
|
325 |
+
]
|
326 |
+
}
|
327 |
+
],
|
328 |
+
"metadata": {
|
329 |
+
"kernelspec": {
|
330 |
+
"display_name": "Python [conda env:py39] *",
|
331 |
+
"language": "python",
|
332 |
+
"name": "conda-env-py39-py"
|
333 |
+
},
|
334 |
+
"language_info": {
|
335 |
+
"codemirror_mode": {
|
336 |
+
"name": "ipython",
|
337 |
+
"version": 3
|
338 |
+
},
|
339 |
+
"file_extension": ".py",
|
340 |
+
"mimetype": "text/x-python",
|
341 |
+
"name": "python",
|
342 |
+
"nbconvert_exporter": "python",
|
343 |
+
"pygments_lexer": "ipython3",
|
344 |
+
"version": "3.9.0"
|
345 |
+
}
|
346 |
+
},
|
347 |
+
"nbformat": 4,
|
348 |
+
"nbformat_minor": 5
|
349 |
+
}
|
setup.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from os import path
|
5 |
+
|
6 |
+
from setuptools import find_packages
|
7 |
+
from setuptools import setup
|
8 |
+
|
9 |
+
|
10 |
+
this_directory = path.abspath(path.dirname(__file__))
|
11 |
+
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
|
12 |
+
long_description = f.read()
|
13 |
+
|
14 |
+
|
15 |
+
setup(
|
16 |
+
name="mario-gpt",
|
17 |
+
version="0.1.0",
|
18 |
+
url="https://github.com/kragniz/cookiecutter-pypackage-minimal",
|
19 |
+
license='MIT',
|
20 |
+
|
21 |
+
author="Shyam Sudhakaran",
|
22 |
+
author_email="[email protected]",
|
23 |
+
|
24 |
+
description="Generating Mario Levels with GPT2. Code for the paper: 'MarioGPT: Open-Ended Text2Level Generation through Large Language Models', https://arxiv.org/abs/2302.05981",
|
25 |
+
|
26 |
+
long_description=long_description,
|
27 |
+
long_description_content_type="text/markdown",
|
28 |
+
|
29 |
+
packages=find_packages(exclude=('tests',)),
|
30 |
+
|
31 |
+
install_requires=[
|
32 |
+
'torch',
|
33 |
+
'transformers',
|
34 |
+
'scipy',
|
35 |
+
'tqdm'
|
36 |
+
],
|
37 |
+
|
38 |
+
classifiers=[
|
39 |
+
'Development Status :: 2 - Pre-Alpha',
|
40 |
+
'License :: OSI Approved :: MIT License',
|
41 |
+
'Programming Language :: Python :: 3',
|
42 |
+
],
|
43 |
+
)
|
static/architecture.png
ADDED
static/prompt-samples.png
ADDED