# mario_gpt/prompter.py
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from scipy import stats
from transformers import pipeline
from mario_gpt.dataset import MarioDataset
from mario_gpt.utils import view_level
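
# Default thresholds for bucketing raw pipe/enemy/block counts into the
# "no"/"little"/"some"/"many" keywords; dataset_statistics() below recomputes
# comparable thresholds (33rd/66th/95th percentile counts) from a MarioDataset.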
STATISTICS = {
"enemy": np.array([1.0, 3.0, 7.0]),
"pipe": np.array([0.0, 2.0, 5.0]),
"block": np.array([50.0, 75.0, 176.0]),
}
FEATURE_EXTRACTION_MODEL = "facebook/bart-base"
class Prompter:
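    """Turn a tokenized Mario level into a short text prompt describing its
    pipes, enemies, blocks, and elevation, and encode that prompt with a BART
    feature-extraction pipeline."""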
def __init__(
self,
level_tokenizer,
prompter_model: str = FEATURE_EXTRACTION_MODEL,
use_raw_counts: bool = False,
statistics: Optional[Dict[str, Any]] = None,
):
self.prompter_model = prompter_model
self.feature_extraction = pipeline(
"feature-extraction",
model=prompter_model,
tokenizer=prompter_model,
framework="pt",
)
self.level_tokenizer = level_tokenizer
self.use_raw_counts = use_raw_counts
self.statistics = statistics
if statistics is None:
self.statistics = STATISTICS
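
    # Each *_thresholds property pairs the numeric cut points with the keyword
    # for each bucket; np.digitize(count, thresholds, right=True) picks the index.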
@property
    def pipe_thresholds(self) -> Tuple[np.ndarray, List[str]]:
thresholds = self.statistics["pipe"]
keywords = ["no", "little", "some", "many"]
return thresholds, keywords
@property
    def enemy_thresholds(self) -> Tuple[np.ndarray, List[str]]:
thresholds = self.statistics["enemy"]
keywords = ["no", "little", "some", "many"]
return thresholds, keywords
@property
    def block_thresholds(self) -> Tuple[np.ndarray, List[str]]:
        thresholds = self.statistics["block"]
        keywords = ["little", "little", "some", "many"]  # no "no" bucket; levels always have blocks
return thresholds, keywords
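
    # The count helpers below operate on the flattened (row-concatenated) level
    # string: "<>" marks a pipe top, "E"/"B" are counted as enemies, and
    # "X", "S", "?", "Q" count as blocks.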
def count_pipes(self, flattened_level: str) -> int:
return flattened_level.count("<>")
def count_enemies(self, flattened_level: str) -> int:
return flattened_level.count("E") + flattened_level.count("B")
def count_blocks(self, flattened_level: str) -> int:
        return sum(flattened_level.count(char) for char in ("X", "S", "?", "Q"))
def _flatten_level(self, string_level: List[str]) -> str:
return "".join(string_level)
    def pipe_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
count = self.count_pipes(flattened_level)
keyword = f"{count}"
if not self.use_raw_counts:
thresholds, keywords = self.pipe_thresholds
threshold = np.digitize(count, thresholds, right=True)
keyword = keywords[threshold]
return f"{keyword} pipes", keyword
    def enemy_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
count = self.count_enemies(flattened_level)
keyword = f"{count}"
if not self.use_raw_counts:
thresholds, keywords = self.enemy_thresholds
threshold = np.digitize(count, thresholds, right=True)
keyword = keywords[threshold]
return f"{keyword} enemies", keyword
    def block_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
count = self.count_blocks(flattened_level)
keyword = f"{count}"
if not self.use_raw_counts:
thresholds, keywords = self.block_thresholds
threshold = np.digitize(count, thresholds, right=True)
keyword = keywords[threshold]
return f"{keyword} blocks", keyword
    def elevation_prompt(
        self, flattened_level: str, level: List[str]
    ) -> Tuple[str, str]:
        top_rows = level[:6]  # top rows of the level grid (elevation 8 and up)
        for row in top_rows:
            if "X" in row or "<" in row or ">" in row:
                return "high elevation", "high"
        return "low elevation", "low"
    def output_hidden(
        self, prompt: str, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        # Mean-pool the token embeddings into a single 768-dim prompt vector
return (
self.feature_extraction(prompt, return_tensors="pt")[0]
.mean(0)
.to(device)
.view(1, -1)
)
    def dataset_statistics(self, dataset: MarioDataset) -> Dict[str, Any]:
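        """Estimate pipe/enemy/block count thresholds (33rd/66th/95th percentile
        counts per level); the returned dict can be passed as `statistics`."""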
enemy_counts = []
pipe_counts = []
block_counts = []
for i in range(len(dataset)):
level, _ = dataset[i]
str_level = self._flatten_level(view_level(level, dataset.tokenizer))
enemy_count = self.count_enemies(str_level)
pipe_count = self.count_pipes(str_level)
block_count = self.count_blocks(str_level)
enemy_counts.append(enemy_count)
pipe_counts.append(pipe_count)
block_counts.append(block_count)
d = {"enemy": {}, "pipe": {}, "block": {}}
d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
return d
    def __call__(
        self, level: Optional[torch.Tensor] = None, sample_prompt: bool = False
    ) -> Tuple[str, torch.Tensor, Dict[str, str], Optional[List[str]]]:
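        """Return (prompt, hidden, prompt_dict, str_level) for a level tensor,
        or for a randomly sampled prompt when sample_prompt=True."""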
device: torch.device = torch.device("cpu")
if not sample_prompt:
if level is None:
raise ValueError("Level must be provided if sample_prompt is not true!")
str_level = view_level(level, self.level_tokenizer)
flattened_level = self._flatten_level(str_level)
pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
block_prompt, _ = self.block_prompt(flattened_level, str_level)
elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
device = level.device
else:
str_level = None
pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
block_prompt = (
random.choice(["little", "little", "some", "many"]) + " blocks"
) # levels always have blocks
            elevation_prompt = random.choice(["low", "high"]) + " elevation"
prompt_dict = {
"pipe": pipe_prompt,
"enemy": enemy_prompt,
"block": block_prompt,
"elevation_prompt": elevation_prompt,
}
prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
hidden = self.output_hidden(prompt, device=device)
return prompt, hidden, prompt_dict, str_level
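

if __name__ == "__main__":
    # Minimal usage sketch: sample a random prompt and embed it. This assumes
    # network access to fetch facebook/bart-base on first run; level_tokenizer
    # is only consulted when encoding a real level tensor, so None suffices here.
    prompter = Prompter(level_tokenizer=None)
    prompt, hidden, prompt_dict, _ = prompter(sample_prompt=True)
    print(prompt)        # e.g. "some pipes, many enemies, little blocks, low elevation"
    print(prompt_dict)   # per-feature phrases keyed by "pipe", "enemy", "block", ...
    print(hidden.shape)  # torch.Size([1, 768]) for bart-base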